aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSIPB2024-12-08 01:58:46 -0500
committerSIPB2024-12-08 01:58:46 -0500
commit1ea92f93e0dd5a8254fd3b0df5cc84ccdb4e505b (patch)
treede49b21ef6c3a8985e8ba8e3f73da625df15641c
parent77ad008124f9825ac80c7328923827a2212d600b (diff)
More work on blog.md, move images to img folder
-rw-r--r--blog.md181
-rw-r--r--fine-tuning-2d-histogram.pngbin8484 -> 0 bytes
-rw-r--r--fine-tuning-loss.pngbin15999 -> 0 bytes
-rw-r--r--finetune.dot9
-rw-r--r--grokking.pngbin20549 -> 0 bytes
-rw-r--r--img/embeddings.svg37
-rw-r--r--img/finetune.svg50
-rw-r--r--img/test.svg127
-rw-r--r--img/train.svg127
-rw-r--r--index.html447
-rw-r--r--insane-shortest-paths.ipynb451
-rw-r--r--loss8388
-rw-r--r--model.pthbin422298 -> 1233035 bytes
-rw-r--r--test-2d-histogram.pngbin10783 -> 0 bytes
-rw-r--r--test.dot18
-rw-r--r--train.dot18
-rw-r--r--training-2d-histogram.pngbin12615 -> 0 bytes
-rw-r--r--training-loss.pngbin20425 -> 0 bytes
-rw-r--r--transformer_shortest_paths.ipynb3603
19 files changed, 12318 insertions, 1138 deletions
diff --git a/blog.md b/blog.md
index c599f21..10333f9 100644
--- a/blog.md
+++ b/blog.md
@@ -1,7 +1,7 @@
---
build: pandoc blog.md --citeproc --katex -s -o index.html
mkzip: zip project.zip index.html *.png
-title: "6.7960 Project: Investigating Off-Distribution Generalization of Transformers"
+title: "Discovering Graph Algorithms Using Transformers"
bibliography: blog.bib
link-citations: true
---
@@ -14,25 +14,15 @@ Anthony Wang, Alek Westover, Kevin Zhao
{xy,alekw,kevinmz}\@mit.edu
</div>
-## Goals
+## Motivation
-Recently, LLMs have been developing very fast, and with that comes the concern of aligning the models to output true and productive statements. One common approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF), but one drawback to this problem is that humans can be poor judges of truthfulness. As LLMs become more capable, there might not even exist experts that are good judges of whether the model's outputs, such as difficult mathematical proofs, are truthful. So, we'd like to propose a potential solution to this issue via **off-distribution generalization** - applying human-like intuition to solve problems not in the dataset. Paul Christiano [proposed an experiment](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX) about shortest paths in a graph; our project is essentially to implement Christiano's proposed experiment. To the best of our knowledge, although there has been research in applying machine learning for different variations of graph searches [@10.5555/3666122.3666260], no one has done our exact experiment yet.
+Transformers--the architecture that powers LLMs--can do incredible feats: trained on hundreds of gigabytes of raw text, they can learn to hold natural conversations, reason about the physical world, and write code. Skeptics argue that LLMs are simply memorizing their datasets without gaining any deeper understanding. For instance, GPT's o1 model, achieving 90th percentile on Codeforces, struggles with simple but bizarre algorithms problems such as "find the subarray of a 2D array with the minimum average". In this project, we hope to explore **when off-distribution generalization happens in a transformer**. Paul Christiano proposed an experiment [here](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX) about shortest paths in a graph to investigate this, so we decided to become the first to implement his experiment and put transformers' generalization abilities to the test.
-It is generally desirable for LLMs to output true statements. A current approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF); however, humans can be poor judges of truthfulness. We enjoy many cognitive biases and might employ superficial heuristics when judging truthfulness. A further challenge is that as LLMs develop further, there might not even exist experts that can correctly judge the accuracy and truthfulness of sophisticated outputs such as difficult mathematical proofs.
+LLMs are notorious for making up complete nonsense, so we also hope that our project can shed light on when truthfulness generalizes. It's generally desirable for LLMs to output true statements. One current approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF). However, humans can be poor judges of truthfulness and have many cognitive biases and superficial heuristics. A further challenge is that as LLMs become more capable, there might not even exist experts that are good judges of whether the models outputs, such as difficult mathematical proofs, are truthful. For instance, most Task Rabbit workers would probably be hard pressed to evaluate whether a difficult mathematical proof produced by an LLM is true. The entire mathematical community has been known on occasion to [believe false statements for many years](https://en.wikipedia.org/wiki/Grunwald%E2%80%93Wang_theorem).
-One approach to solving this problem is to reward an LLM for truthful behavior on simple inputs, and then hoping that the LLM generalizes its truthful behavior for more complex inputs where humans cannot provide helpful labels. Deep learning models often perform remarkable feats of off-distribution generalization -- for instance, a model trained to transform hand drawn cats into images of cats might be able to handle a "cat" with three eyes in an intuitive way. We might hope that generalizing truthfully is simple, thus promoted by "Occam's Razor", and aim to investigate that with this project.
+One possible solution is to reward an LLM for truthful behavior on simple inputs, and then hope that the LLM generalizes its truthful behavior for more complex inputs where humans cannot provide helpful labels. Deep learning models can be remarkably good at off-distribution generalization--for instance, a model trained to transform hand drawn cats into images of cats might be able to handle a "cat" with three eyes in an intuitive way. We might hope that generalizing truthfully is simple, thus promoted by "Occam's Razor".
-COMMENT FROM KEVIN -- synthesize from intorduction
-
-### Task
-
-We will use a synthetic task to test our hypothesis that models will generalize truthfully off-distribution. The synthetic task is computing the distance between various vertices in an input graph. Our experiment will have three parts:
-
-1. Pre-train a transformer to predict the distance between two fixed vertices $s,t$ on graphs with $n\in [8, 32)$ vertices.
-2. Fine-tune a transformer to predict the distances between $s,t'$ for any $t'$ which is on the shortest path from $s$ to $t$, but only do fine-tuning on graphs with $n\in [8,16)$ vertices.
-3. Test whether the transformer can accurately predict the distances between $s,t'$ for any $t'$ on the shortest path from $s$ to $t$ for graphs with $n\in [16,32)$ vertices.
-
-### Related Work
+## Related Work
COMMENT FROM ALEK
-- please remove all mentions of graph neural networks -- that is BS: there is no actual reason why you'd ever use a Neural network to solve shortest paths, the point of choosing a synthetic task is because there is a **simple ground truth** which makes it easy to evaluate whether or not our model is performing correctly. We'd also hoped that the simplicity of the task would make it more feasible to do with a limited compute budget, but apparently this task was too hard for our architecture.
@@ -44,82 +34,58 @@ There has been some research into the algorithmic optimization of GNNs and how t
- Tutsoy uses a graph-theory-based approach to model the epidemiological characteristics of infectious diseases, such as COVID-19 [@10.1109/TPAMI.2023.3256421]. We understand from his paper how GNN optimization may also be useful in researching novel diseases.
-## Methods
-
-### Algorithm for Shortest Paths
-
-The standard algorithm to find the shortest path in a graph between a source numbered as $u$ and sink numbered as $v$ is **breadth-first search (BFS)**. The BFS algorithm maintains a mapping of visited vertices to their distances with respect to $u$, and each run of the algorithm goes through all the vertices newly visited in the previous run, and for each vertex, visits any of its unvisited neighbors. The algorithm terminates once either $v$ is visited or the set of newly visited vertices in a single run is empty.
-
-We will use this algorithm to verify the accuracy of our machine learning approach. Given $V$ vertices and $E$ edges, the runtime of this algorithm is thus $O(V + E)$; however, a machine learning approach may do better in time through parallelism, although at the expense of using much more memory.
-
-### Data
-
-We will represent an $n$ vertex, $m$ edge unweighted, undirected graph as sequence of the endpoints of the $m$ edges, so $[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]$ represents a graph with the edges $\{(a_i,b_i)\}$ for $1 \leq i \leq m$. We will pad all sequences to be the same length using the padding token 0.
-
-The full input to our model will additionally add the target vertex after the padding tokens. The model is tasked with predicting the length of the shortest path between vertex 1 and the target vertex $t$. If no such path exists, we define the length to be $n+1$ which represents infinity. For example, an input-output pair for our model could look like $[1, 3, 3, 2, 0, 0, 0, 0, 2]$ and $2$ respectively.
-
-We have three separate datasets.
-
-- **Pre-train data**: For each $n \in [8,32)$, we will generate several graphs on $n$ vertices. We generate these graphs by inserting $2n$ random edges into the graph. We always set the target vertex to be $2$ here.
-- **Fine-tune data**: For each $n \in [8,16)$, we will generate several graphs on $n$ vertices. We generate these graphs by inserting $2n$ random edges into the graph. We select the target vertex to be a random vertex on the shortest path from $1$ to $2$.
-- **Generalization testing data**: The same as the fine-tune data, except we sample $n \in [16,32)$ instead.
-
-We wrote some Python code to generate the data during the training loop, but Python is slow and the data generation wasted a lot of time during training. To get around this, we pre-generated the data before training and made our Python code multithreaded to speed it up.
-
-### Architecture
-
-TODO: honestly not much to say here since it's a pretty typical arch
-
-We plan to use a standard transformer architecture. We will ensure that the number of layers in our transformer is at least the diameter of the graph. By doing this, we ensure that there is an extremely simple circuit --- namely BFS --- that the transformer could in theory learn to perform the task. Note that if the transformer actually learns a simple circuit to perform this task, then it seems more likely to generalize well. This is also our intuition for why it should be possible to fine tune on a small amount of data for finding shortest paths to other vertices besides $2$ -- it seems like the model should be computing these other distances as intermediate values in its computation to find the distance to vertex $2$.
-
-### Embeddings
-
-Since the order of the edges in the input does not matter, we did not use positional encodings. Each edge $(u,v)$ where $u < v$ is embedded to a dimension of $d$ where the first $\frac{d}{2}$ elements are the learned embedding of $u$ and the last $\frac{d}{2}$ elements are the learned embedding of $v$. For the target vertex $t$, we also embedded to dimension $d$, where the first $\frac{d}{2}$ elements are the learned embedding of $t$ and the last $\frac{d}{2}$ are a learned embedding of a special token.
-
-## Training
-
-For our model, we used a model dimension of 64, four layers, and two heads per layer, for a total of 200545 parameters in bfloat16 which corresponds to around 3.2e6 bits. The number of possible graphs on 15 vertices generated using our procedure is approximately
-$$\frac{\binom{15}{2}^{15}}{15!} = 1.59\cdot10^{18}.$$
-
-This is because there are $\binom{15}{2}$ choices for each of the 15 edges and we don't care about the order of the edges. This is only an approximation because some edges might be duplicated. Each graph has an answer between 1 and 15 which requires around 4 bits, so memorizing all the answers requires $4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}$, which is $2\cdot10^{12}$ times larger than our model size.
+## Task
+Our synthetic task is simple: compute the distance between various vertices in an input graph. To test off-distribution generalization, our experiment has three steps.
-We used MSE loss, the Adam optimizer, a learning rate of 8e-4, and a batch size of 131072 for 8000 unique randomly generated batches. Our final MSE loss was approximately 0.3555.
+1. **Pre-train** a transformer to predict the distance between vertices $1$ and $2$ in graphs with $n \in [3,15]$ vertices.
-![](training-loss.png)
+<div style="text-align:center">
+![](img/train.svg)
+</div>
-![](training-2d-histogram.png)
+2. **Fine-tune** a transformer to predict the distances between vertex $1$ to $t$ for any $t$ on the shortest path from $1$ to $2$, but only do fine-tuning on graphs with $n \in [3,7]$ vertices.
-One pattern we consistently noticed during training is that the model often gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened between epochs 100 and 300 in the graph above:
+<div style="text-align:center">
+![](img/finetune.svg)
+</div>
-![](grokking.png)
+3. **Test** whether the transformer can accurately predict the distances between $1$ to $t$ for any $t \leq 7$ on the shortest path from $1$ to $2$ for graphs with $n \in [3,15]$ vertices.
-"grokking" hypothesis: it's memorizing all length 2 paths?
+<div style="text-align:center">
+![](img/test.svg)
+</div>
-TODO: training curves for 1, 2, 3 length paths
+### Algorithm for Shortest Paths
-### Potential Mathematical Approaches to Shortest Paths? Delete this?
+The standard algorithm to find the shortest path in a graph between vertices $u$ and $v$ is **breadth-first search (BFS)**, taught in every intro algorithms class. Initially, BFS starts at $u$, and at each phase, explores a farther layer of vertices from $u$. During a phase, BFS goes through every vertex in the current layer and adds any of their unvisited neighbors to the next layer. The algorithm terminates once we reach $v$ or if the next layer is empty. For a graph with $V$ vertices and $E$ edges, the runtime of BFS is $O(V + E)$. BFS gives us an easy and fast way to find the ground truth answer for any graph, so that we can verify the accuracy of our machine learning approach.
-Another way one can think of the shortest path of a graph is using a *matrix* to record which vertices are connected. Given vertices numbered $1$ to $V$, we denote the **adjacency matrix** $\textbf{M}$ of dimensions $V \times V$ as the matrix with element $\textbf{M}_{i, j} = 1$ if vertices $i$ and $j$ are connected by an edge and $\textbf{M}_{i, j} = 0$ if they are not. Now, we note that (1) For all $k$, $(\textbf{M}+I)^k_{i, j} = 0$ if and only if there exists no path from the vertex numbered $i$ to the vertex numbered $j$ that is distance $k$ or less due to Markov matrix processes. As a result, if the distance between vertices numbered $i$ and $j$ is $d$, then $\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = 1$ if $k \ge d$ and $\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = 0$ if $k < d$.
+We hope that our model can learn BFS or some other simple, generalizable algorithm for shortest paths, because the model can't just pull some magic number out of a hat but intuitively needs to count or list the vertices on the shortest path from $1$ to $2$. In fact, we will show how to hand-craft a set of weights to implement BFS in a transformer, so it's indeed theoretically possible for a transformer to achieve 100% accuracy.
-With this information, because the distance between any two vertices is at most $V-1$ in a graph with $V$ vertices, we note that the *distance* matrix turns out to be simply $$\textbf{D} = \textbf{1}_{V \times V} \cdot V - \Sigma_{i=0}^{V-1}\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right).$$ The runtime to compute this is $O(V)$, although it will take more space to compute all powers of $\textbf{M}$.
+### Data
-## Fine tuning results
+We'll represent an $n$ vertex, $m$ edge unweighted, undirected graph as sequence of the endpoints of the $m$ edges, so $[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]$ represents a graph with the edges $\{(a_i,b_i)\}$ where $a_i < b_i$ for $1 \leq i \leq m$. All sequences are padded to the same length using the padding token $0$.
-After receiving our initial results, we fine-tuned with a learning rate of 1e-5, also with MSE and the same batch size. Our final results are shown in the images below.
+The full input to our model additionally includes the target vertex $t$ after the padding tokens. The label to an input is the length of the shortest path from $1$ to $t$. If no such path exists, we define the length to be $n+1$ which represents infinity. For example, the input $[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0, 3]$ has the label $2$.
-![](fine-tuning-loss.png)
+<div style="text-align:center">
+![](img/finetune.svg)
+</div>
-![](fine-tuning-2d-histogram.png)
+We have three datasets for each step.
-![](test-2d-histogram.png)
+1. **Pre-train data**: For each $n \in [3,15]$, we generated an equal number of graphs on $n$ vertices, with $t = 2$. Each graph was created by choosing $n$ random edges.
+2. **Fine-tune data**: For each $n \in [3,7]$, we generated an equal number of graphs on $n$ vertices each with a random $t$ on the shortest path from $1$ to $2$. Again, each graph was created by choosing $n$ random edges.
+3. **Generalization test data**: The same as the fine-tune data, except we sample $n \in [3,15]$ and $t \leq 7$.
-Memorization? Do some math here to compute how many bits required to memorize 1, 2, 3
+We wrote some Python code to generate the data during the training loop, but Python is excruciatingly slow and data generation wasted a lot of training time. Our solution was to pre-generate the data before training using a multithreaded version of our Python code.
## Complicated explicit transformer formula for shortest paths
+TODO: Kevin or Alek
+
```py
# Configuration
NVTXS = 16
@@ -221,9 +187,7 @@ class SillyTransformer(nn.Module):
return final_output
```
-It looked like the fine tuning results weren't as meaningful because TODO: ALEK
-## Customizing a Transformer
After much deliberation, we decided the next step for us was to customize a transformer, writing it ourselves. We observed that we wished for the transformer to do similar work as a BFS. As a result, we decided to work with the following transformer, for a graph with $n$ vertices $v_1, v_2, \cdots, v_n$:
@@ -241,8 +205,81 @@ Specifically, we see that $\text{NBR}_{i}$ is a $n \times 1$ vector detailing wh
This iterates through all the vertices, and at the end, we may see what run we are on to update $V_{\text{OUT}}$.
+## Perturbing the Weights
+
+SLT folks like to think about geometry of loss landscape CITE
+So we did an experiment where we mess with the weights.
+
+Findings: XXX
+
+
+## Our Model
+
+### Architecture
+
+We used a standard transformer architecture. To ensure that it can in theory learn BFS, we ensured that the number of layers in our transformer exceeds the diameter of the input graphs.
+
+Since the order of the edges in the input doesn't matter, we did not use positional encodings. Each edge $(a,b)$ is embedded to dimension $d$ where the first $\frac{d}{2}$ elements are the learned embedding of $a$ and the last $\frac{d}{2}$ elements are the learned embedding of $b$. For the target vertex $t$, we pair it with the special token $TARGET$ and embed $(t,TARGET)$ in the same way.
+
+<!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm -->
+![](img/embeddings.svg)
+
+### Training
+
+To match the BFS transformer as closely as possible, we used a model dimension of $64$, $11$ layers, and $2$ heads per layer, for a total of 550433 parameters. In 32-bit float precision, that corresponds to around $1.76\cdot10^6$ bits. The number of possible graphs on 15 vertices generated using our procedure is approximately
+
+$$\frac{\binom{15}{2}^{15}}{15!} = 1.59\cdot10^{18}.$$
+
+This is because there are $\binom{15}{2}$ choices for each of the 15 edges and we don't care about the order of the edges. This is only an approximation because some edges might be duplicated. Each graph has an answer between 1 and 15 which requires around 4 bits, so memorizing all the answers requires $4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}$ bits, which is $3.61\cdot10^{12}$ times larger than our model size.
+
+To train the model, we used MSE loss, the Adam optimizer, a learning rate of $3\cdot10^{-4}$, and a batch size of $2^{15}$ for one billion randomly generated graphs. A training run takes roughly eight hours to run on a Radeon 7900 XTX graphics card. Our final MSE loss was $0.000555$.
+
+TODO: use https://mpld3.github.io/index.html to make interactive plots
+
+![](training-loss.png)
+
+![](training-2d-histogram.png)
+
+One pattern we consistently noticed during training is that the model often gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened between epochs 100 and 300 in the graph above:
+
+![](grokking.png)
+
+"grokking" hypothesis: it's memorizing all length 2 paths?
+
+TODO: cite Neel Nanda grokking modular addition
+
+TODO: CRAZY!!! training curves for 1, 2, 3 length paths
+
+One pitfall we encountered during training is that we initially used bfloat16 to save VRAM, but our huge batch size caused loss-of-precision problems and made training very difficult. It took us two weeks to debug this until we found that switching to float32 improved training significantly.
+
+## Fine tuning results
+
+After receiving our initial results, we fine-tuned with a learning rate of 1e-5, also with MSE and the same batch size. Our final results are shown in the images below.
+
+![](fine-tuning-loss.png)
+
+![](fine-tuning-2d-histogram.png)
+
+![](test-2d-histogram.png)
+
+TODO: get new graphs
+
+It's pretty good!!!
+
+Can only generalize to target vertices from 2 to 7 since 8 through 15 didn't appear in the fine-tune data
+
+but this still means it
+
## Conclusion
+however, a machine learning approach may do better in time through parallelism, although at the expense of using much more memory.
+**TODO: ALEK: this is BS. If you want a parallel algorithm for BFS, here's one https://en.wikipedia.org/wiki/Parallel_single-source_shortest_path_algorithm**
+
just do bfs lol
+**Future Work**
+There are a couple of other things that we could try to learn shortest paths better and maybe see more generalization.
+- Chain of thought
+- Train model to output a path, not just the distance. Give it partial points for outputting anything that is a legitimate path (i.e., consists of real edges) and more points for getting the distance correct.
+
## References
diff --git a/fine-tuning-2d-histogram.png b/fine-tuning-2d-histogram.png
deleted file mode 100644
index df45973..0000000
--- a/fine-tuning-2d-histogram.png
+++ /dev/null
Binary files differ
diff --git a/fine-tuning-loss.png b/fine-tuning-loss.png
deleted file mode 100644
index 7e95a06..0000000
--- a/fine-tuning-loss.png
+++ /dev/null
Binary files differ
diff --git a/finetune.dot b/finetune.dot
new file mode 100644
index 0000000..f80b54b
--- /dev/null
+++ b/finetune.dot
@@ -0,0 +1,9 @@
+graph G {
+ 1 [color=blue];
+ 2 [color=red];
+ 3 [color=green];
+ 1 -- 3 [color=red];
+ 4 -- 3;
+ 2 -- 4;
+ 2 -- 3 [color=red];
+}
diff --git a/grokking.png b/grokking.png
deleted file mode 100644
index 848b4ab..0000000
--- a/grokking.png
+++ /dev/null
Binary files differ
diff --git a/img/embeddings.svg b/img/embeddings.svg
new file mode 100644
index 0000000..6e7b510
--- /dev/null
+++ b/img/embeddings.svg
@@ -0,0 +1,37 @@
+<?xml version="1.0" encoding="utf-8"?>
+<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="451.1200256347656 187.3000030517578 532.1399841308594 229.5599822998047" width="532.1399841308594px" height="229.5599822998047px">
+<style>
+.pen line { stroke-linecap: round }
+.pen polyline { stroke-linecap: round; stroke-linejoin: round; fill: none }
+.grid { stroke-width: 0.96; stroke: #c4e3f4 }
+text { font-family: "Roboto Slab", serif }
+@font-face { font-family: "Roboto Slab"; src: url("https://fonts.gstatic.com/s/robotoslab/v34/BngMUXZYTXPIvIBgJJSb6ufN5qU.woff2") format("woff2"); font-style: normal; font-weight: 500; font-display: swap; unicode-range: U+0-FF, U+131, U+152-153, U+2BB-2BC, U+2C6, U+2DA, U+2DC, U+304, U+308, U+329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; }
+</style>
+<g class="grid"/>
+<rect data-id="6xjpYJRwx2i9Lk4ka" x="566.4" y="188.79999999999998" width="37.75999999999999" height="226.55999999999997" stroke="black" stroke-width="3" stroke-linejoin="round" fill="none"/>
+<rect data-id="zoZEjjCquWhaApnri" x="717.4399999999999" y="188.79999999999998" width="37.75999999999999" height="226.55999999999997" stroke="black" stroke-width="3" stroke-linejoin="round" fill="none"/>
+<ellipse data-id="RXkH59ujgaxLJaxT8" cx="792.9599999999999" cy="302.08" rx="6" ry="6" stroke="black" stroke-width="3" fill="black"/>
+<ellipse data-id="uAWPFosgmQNTg8N4k" cx="830.7199999999999" cy="302.08" rx="6" ry="6" stroke="black" stroke-width="3" fill="black"/>
+<ellipse data-id="nPcPPyrfRZWh6J6Ya" cx="868.4799999999999" cy="302.08" rx="6" ry="6" stroke="black" stroke-width="3" fill="black"/>
+<polyline data-id="ktor56FqpMui8qxAj" points="566.4,188.79999999999998 566.4,339.84" stroke="black" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
+<g data-id="gFW86zFLkDdze6ozK" transform="translate(-77.52 -10)"><g transform="translate(528.64,264.32)"><text fill="black" style="font-size:19px">embed(<tspan data-tex="a_1" data-display="false" dx="21.712345"> </tspan>)<tspan class="cursor"> </tspan></text><g transform="translate(67.5666649978447 -10.153268999999998) scale(19)"><svg xmlns="http://www.w3.org/2000/svg" width="1.142755" height="0.699251" role="img" focusable="false" viewBox="0 -441 965.6 591" xmlns:xlink="http://www.w3.org/1999/xlink"><defs><path id="MJX-1-TEX-I-1D44E" d="M33 157Q33 258 109 349T280 441Q331 441 370 392Q386 422 416 422Q429 422 439 414T449 394Q449 381 412 234T374 68Q374 43 381 35T402 26Q411 27 422 35Q443 55 463 131Q469 151 473 152Q475 153 483 153H487Q506 153 506 144Q506 138 501 117T481 63T449 13Q436 0 417 -8Q409 -10 393 -10Q359 -10 336 5T306 36L300 51Q299 52 296 50Q294 48 292 46Q233 -10 172 -10Q117 -10 75 30T33 157ZM351 328Q351 334 346 350T323 385T277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q217 26 254 59T298 110Q300 114 325 217T351 328Z"/><path id="MJX-1-TEX-N-31" d="M213 578L200 573Q186 568 160 563T102 556H83V602H102Q149 604 189 617T245 641T273 663Q275 666 285 666Q294 666 302 660V361L303 61Q310 54 315 52T339 48T401 46H427V0H416Q395 3 257 3Q121 3 100 0H88V46H114Q136 46 152 46T177 47T193 50T201 52T207 57T213 61V578Z"/></defs><g stroke="black" fill="black" stroke-width="0" transform="scale(1,-1)"><g><g><g><use xlink:href="#MJX-1-TEX-I-1D44E"/></g><g transform="translate(562,-150) scale(0.707)"><use xlink:href="#MJX-1-TEX-N-31"/></g></g></g></g></svg></g></g></g>
+<polyline data-id="grMWHhWxJc6KRc6nL" points="537,231 555,196" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
+<polyline data-id="yRWMuguDhZb3YbGfe" points="537,264 554,297" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
+<g data-id="m6RKrpB3nnArJNhjM" transform="translate(-76.52 101)"><g transform="translate(528.64,264.32)"><text fill="black" style="font-size:19px">embed(<tspan class="cursor"> </tspan><tspan data-tex="b_1" data-display="false" dx="19.456646000000003"> </tspan>)</text><g transform="translate(69.46666636600875 -15.847169999999998) scale(19)"><svg xmlns="http://www.w3.org/2000/svg" width="1.024034" height="0.99893" role="img" focusable="false" viewBox="0 -694 865.6 844" xmlns:xlink="http://www.w3.org/1999/xlink"><defs><path id="MJX-3-TEX-I-1D44F" d="M73 647Q73 657 77 670T89 683Q90 683 161 688T234 694Q246 694 246 685T212 542Q204 508 195 472T180 418L176 399Q176 396 182 402Q231 442 283 442Q345 442 383 396T422 280Q422 169 343 79T173 -11Q123 -11 82 27T40 150V159Q40 180 48 217T97 414Q147 611 147 623T109 637Q104 637 101 637H96Q86 637 83 637T76 640T73 647ZM336 325V331Q336 405 275 405Q258 405 240 397T207 376T181 352T163 330L157 322L136 236Q114 150 114 114Q114 66 138 42Q154 26 178 26Q211 26 245 58Q270 81 285 114T318 219Q336 291 336 325Z"/><path id="MJX-3-TEX-N-31" d="M213 578L200 573Q186 568 160 563T102 556H83V602H102Q149 604 189 617T245 641T273 663Q275 666 285 666Q294 666 302 660V361L303 61Q310 54 315 52T339 48T401 46H427V0H416Q395 3 257 3Q121 3 100 0H88V46H114Q136 46 152 46T177 47T193 50T201 52T207 57T213 61V578Z"/></defs><g stroke="black" fill="black" stroke-width="0" transform="scale(1,-1)"><g><g><g><use xlink:href="#MJX-3-TEX-I-1D44F"/></g><g transform="translate(462,-150) scale(0.707)"><use xlink:href="#MJX-3-TEX-N-31"/></g></g></g></g></svg></g></g></g>
+<polyline data-id="55e5b8jmmajQmwjbz" points="537,231 555,196" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(1 111)"/>
+<polyline data-id="4gwSf4FdeHvAjEEbp" points="537,264 554,297" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(1 111)"/>
+<g data-id="qTJ276E2nmCYzv8MQ" transform="translate(79.48 -13)"><g transform="translate(528.64,264.32)"><text fill="black" style="font-size:19px">embed(<tspan class="cursor"> </tspan><tspan data-tex="a_2" data-display="false" dx="21.712345"> </tspan>)</text><g transform="translate(69.4666665237236 -10.153268999999998) scale(19)"><svg xmlns="http://www.w3.org/2000/svg" width="1.142755" height="0.699251" role="img" focusable="false" viewBox="0 -441 965.6 591" xmlns:xlink="http://www.w3.org/1999/xlink"><defs><path id="MJX-5-TEX-I-1D44E" d="M33 157Q33 258 109 349T280 441Q331 441 370 392Q386 422 416 422Q429 422 439 414T449 394Q449 381 412 234T374 68Q374 43 381 35T402 26Q411 27 422 35Q443 55 463 131Q469 151 473 152Q475 153 483 153H487Q506 153 506 144Q506 138 501 117T481 63T449 13Q436 0 417 -8Q409 -10 393 -10Q359 -10 336 5T306 36L300 51Q299 52 296 50Q294 48 292 46Q233 -10 172 -10Q117 -10 75 30T33 157ZM351 328Q351 334 346 350T323 385T277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q217 26 254 59T298 110Q300 114 325 217T351 328Z"/><path id="MJX-5-TEX-N-32" d="M109 429Q82 429 66 447T50 491Q50 562 103 614T235 666Q326 666 387 610T449 465Q449 422 429 383T381 315T301 241Q265 210 201 149L142 93L218 92Q375 92 385 97Q392 99 409 186V189H449V186Q448 183 436 95T421 3V0H50V19V31Q50 38 56 46T86 81Q115 113 136 137Q145 147 170 174T204 211T233 244T261 278T284 308T305 340T320 369T333 401T340 431T343 464Q343 527 309 573T212 619Q179 619 154 602T119 569T109 550Q109 549 114 549Q132 549 151 535T170 489Q170 464 154 447T109 429Z"/></defs><g stroke="black" fill="black" stroke-width="0" transform="scale(1,-1)"><g><g><g><use xlink:href="#MJX-5-TEX-I-1D44E"/></g><g transform="translate(562,-150) scale(0.707)"><use xlink:href="#MJX-5-TEX-N-32"/></g></g></g></g></svg></g></g></g>
+<polyline data-id="akW8gHYnjZMKtBCum" points="537,231 555,196" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(157 -3)"/>
+<polyline data-id="BfPaaFDaNi5Hh4Qpu" points="537,264 554,297" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(157 -3)"/>
+<g data-id="GBFNd9TYiKbRf4yQm" transform="translate(78.48 101)"><g transform="translate(528.64,264.32)"><text fill="black" style="font-size:19px">embed(<tspan class="cursor"> </tspan><tspan data-tex="b_2" data-display="false" dx="19.456646000000003"> </tspan>)</text><g transform="translate(69.46666636600875 -15.847169999999998) scale(19)"><svg xmlns="http://www.w3.org/2000/svg" width="1.024034" height="0.99893" role="img" focusable="false" viewBox="0 -694 865.6 844" xmlns:xlink="http://www.w3.org/1999/xlink"><defs><path id="MJX-8-TEX-I-1D44F" d="M73 647Q73 657 77 670T89 683Q90 683 161 688T234 694Q246 694 246 685T212 542Q204 508 195 472T180 418L176 399Q176 396 182 402Q231 442 283 442Q345 442 383 396T422 280Q422 169 343 79T173 -11Q123 -11 82 27T40 150V159Q40 180 48 217T97 414Q147 611 147 623T109 637Q104 637 101 637H96Q86 637 83 637T76 640T73 647ZM336 325V331Q336 405 275 405Q258 405 240 397T207 376T181 352T163 330L157 322L136 236Q114 150 114 114Q114 66 138 42Q154 26 178 26Q211 26 245 58Q270 81 285 114T318 219Q336 291 336 325Z"/><path id="MJX-8-TEX-N-32" d="M109 429Q82 429 66 447T50 491Q50 562 103 614T235 666Q326 666 387 610T449 465Q449 422 429 383T381 315T301 241Q265 210 201 149L142 93L218 92Q375 92 385 97Q392 99 409 186V189H449V186Q448 183 436 95T421 3V0H50V19V31Q50 38 56 46T86 81Q115 113 136 137Q145 147 170 174T204 211T233 244T261 278T284 308T305 340T320 369T333 401T340 431T343 464Q343 527 309 573T212 619Q179 619 154 602T119 569T109 550Q109 549 114 549Q132 549 151 535T170 489Q170 464 154 447T109 429Z"/></defs><g stroke="black" fill="black" stroke-width="0" transform="scale(1,-1)"><g><g><g><use xlink:href="#MJX-8-TEX-I-1D44F"/></g><g transform="translate(462,-150) scale(0.707)"><use xlink:href="#MJX-8-TEX-N-32"/></g></g></g></g></svg></g></g></g>
+<polyline data-id="2bC4FvcDaRi2B74qq" points="537,231 555,196" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(156 111)"/>
+<polyline data-id="MRbTmv3wMm5PRRJ7r" points="537,264 554,297" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(156 111)"/>
+<g data-id="CXkSyTPCNLuG6JBbq" transform="translate(230.96000000000004 101)"><g transform="translate(528.64,264.32)"><text fill="black" style="font-size:19px">embed(<tspan class="cursor"> </tspan><tspan data-tex="TARGET
+" data-display="false" dx="100.423322"> </tspan>)</text><g transform="translate(69.46666517069626 -16.20779) scale(19)"><svg xmlns="http://www.w3.org/2000/svg" width="5.285438" height="0.87341" role="img" focusable="false" viewBox="0 -716 4467 738" xmlns:xlink="http://www.w3.org/1999/xlink"><defs><path id="MJX-21-TEX-I-1D447" d="M40 437Q21 437 21 445Q21 450 37 501T71 602L88 651Q93 669 101 677H569H659Q691 677 697 676T704 667Q704 661 687 553T668 444Q668 437 649 437Q640 437 637 437T631 442L629 445Q629 451 635 490T641 551Q641 586 628 604T573 629Q568 630 515 631Q469 631 457 630T439 622Q438 621 368 343T298 60Q298 48 386 46Q418 46 427 45T436 36Q436 31 433 22Q429 4 424 1L422 0Q419 0 415 0Q410 0 363 1T228 2Q99 2 64 0H49Q43 6 43 9T45 27Q49 40 55 46H83H94Q174 46 189 55Q190 56 191 56Q196 59 201 76T241 233Q258 301 269 344Q339 619 339 625Q339 630 310 630H279Q212 630 191 624Q146 614 121 583T67 467Q60 445 57 441T43 437H40Z"/><path id="MJX-21-TEX-I-1D434" d="M208 74Q208 50 254 46Q272 46 272 35Q272 34 270 22Q267 8 264 4T251 0Q249 0 239 0T205 1T141 2Q70 2 50 0H42Q35 7 35 11Q37 38 48 46H62Q132 49 164 96Q170 102 345 401T523 704Q530 716 547 716H555H572Q578 707 578 706L606 383Q634 60 636 57Q641 46 701 46Q726 46 726 36Q726 34 723 22Q720 7 718 4T704 0Q701 0 690 0T651 1T578 2Q484 2 455 0H443Q437 6 437 9T439 27Q443 40 445 43L449 46H469Q523 49 533 63L521 213H283L249 155Q208 86 208 74ZM516 260Q516 271 504 416T490 562L463 519Q447 492 400 412L310 260L413 259Q516 259 516 260Z"/><path id="MJX-21-TEX-I-1D445" d="M230 637Q203 637 198 638T193 649Q193 676 204 682Q206 683 378 683Q550 682 564 680Q620 672 658 652T712 606T733 563T739 529Q739 484 710 445T643 385T576 351T538 338L545 333Q612 295 612 223Q612 212 607 162T602 80V71Q602 53 603 43T614 25T640 16Q668 16 686 38T712 85Q717 99 720 102T735 105Q755 105 755 93Q755 75 731 36Q693 -21 641 -21H632Q571 -21 531 4T487 82Q487 109 502 166T517 239Q517 290 474 313Q459 320 449 321T378 323H309L277 193Q244 61 244 59Q244 55 245 54T252 50T269 48T302 46H333Q339 38 339 37T336 19Q332 6 326 0H311Q275 2 180 2Q146 2 117 2T71 2T50 1Q33 1 33 10Q33 12 36 24Q41 43 46 45Q50 46 61 46H67Q94 46 127 49Q141 52 146 61Q149 65 218 339T287 628Q287 635 230 637ZM630 554Q630 586 609 608T523 636Q521 636 500 636T462 637H440Q393 637 386 627Q385 624 352 494T319 361Q319 360 388 360Q466 361 492 367Q556 377 592 426Q608 449 619 486T630 554Z"/><path id="MJX-21-TEX-I-1D43A" d="M50 252Q50 367 117 473T286 641T490 704Q580 704 633 653Q642 643 648 636T656 626L657 623Q660 623 684 649Q691 655 699 663T715 679T725 690L740 705H746Q760 705 760 698Q760 694 728 561Q692 422 692 421Q690 416 687 415T669 413H653Q647 419 647 422Q647 423 648 429T650 449T651 481Q651 552 619 605T510 659Q492 659 471 656T418 643T357 615T294 567T236 496T189 394T158 260Q156 242 156 221Q156 173 170 136T206 79T256 45T308 28T353 24Q407 24 452 47T514 106Q517 114 529 161T541 214Q541 222 528 224T468 227H431Q425 233 425 235T427 254Q431 267 437 273H454Q494 271 594 271Q634 271 659 271T695 272T707 272Q721 272 721 263Q721 261 719 249Q714 230 709 228Q706 227 694 227Q674 227 653 224Q646 221 643 215T629 164Q620 131 614 108Q589 6 586 3Q584 1 581 1Q571 1 553 21T530 52Q530 53 528 52T522 47Q448 -22 322 -22Q201 -22 126 55T50 252Z"/><path id="MJX-21-TEX-I-1D438" d="M492 213Q472 213 472 226Q472 230 477 250T482 285Q482 316 461 323T364 330H312Q311 328 277 192T243 52Q243 48 254 48T334 46Q428 46 458 48T518 61Q567 77 599 117T670 248Q680 270 683 272Q690 274 698 274Q718 274 718 261Q613 7 608 2Q605 0 322 0H133Q31 0 31 11Q31 13 34 25Q38 41 42 43T65 46Q92 46 125 49Q139 52 144 61Q146 66 215 342T285 622Q285 629 281 629Q273 632 228 634H197Q191 640 191 642T193 659Q197 676 203 680H757Q764 676 764 669Q764 664 751 557T737 447Q735 440 717 440H705Q698 445 698 453L701 476Q704 500 704 528Q704 558 697 578T678 609T643 625T596 632T532 634H485Q397 633 392 631Q388 629 386 622Q385 619 355 499T324 377Q347 376 372 376H398Q464 376 489 391T534 472Q538 488 540 490T557 493Q562 493 565 493T570 492T572 491T574 487T577 483L544 351Q511 218 508 216Q505 213 492 213Z"/></defs><g stroke="black" fill="black" stroke-width="0" transform="scale(1,-1)"><g><g><use xlink:href="#MJX-21-TEX-I-1D447"/></g><g transform="translate(704,0)"><use xlink:href="#MJX-21-TEX-I-1D434"/></g><g transform="translate(1454,0)"><use xlink:href="#MJX-21-TEX-I-1D445"/></g><g transform="translate(2213,0)"><use xlink:href="#MJX-21-TEX-I-1D43A"/></g><g transform="translate(2999,0)"><use xlink:href="#MJX-21-TEX-I-1D438"/></g><g transform="translate(3763,0)"><use xlink:href="#MJX-21-TEX-I-1D447"/></g></g></g></svg></g></g></g>
+<polyline data-id="72aTqe7sJtijaAs8Z" points="537,231 555,196" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(384 111)"/>
+<polyline data-id="m5AzXmESQfrdz76RT" points="537,264 554,297" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(384 111)"/>
+<g data-id="7XBesfT4DwF8NSrGK" transform="translate(305.48 -10)"><g transform="translate(528.64,264.32)"><text fill="black" style="font-size:19px">embed(<tspan class="cursor"> </tspan><tspan data-tex="t" data-display="false" dx="8.118528999999999"> </tspan>)</text><g transform="translate(69.46666385917283 -14.169717000000002) scale(19)"><svg xmlns="http://www.w3.org/2000/svg" width="0.427291" height="0.7536430000000001" role="img" focusable="false" viewBox="0 -626 361 637" xmlns:xlink="http://www.w3.org/1999/xlink"><defs><path id="MJX-11-TEX-I-1D461" d="M26 385Q19 392 19 395Q19 399 22 411T27 425Q29 430 36 430T87 431H140L159 511Q162 522 166 540T173 566T179 586T187 603T197 615T211 624T229 626Q247 625 254 615T261 596Q261 589 252 549T232 470L222 433Q222 431 272 431H323Q330 424 330 420Q330 398 317 385H210L174 240Q135 80 135 68Q135 26 162 26Q197 26 230 60T283 144Q285 150 288 151T303 153H307Q322 153 322 145Q322 142 319 133Q314 117 301 95T267 48T216 6T155 -11Q125 -11 98 4T59 56Q57 64 57 83V101L92 241Q127 382 128 383Q128 385 77 385H26Z"/></defs><g stroke="black" fill="black" stroke-width="0" transform="scale(1,-1)"><g><g><use xlink:href="#MJX-11-TEX-I-1D461"/></g></g></g></svg></g></g></g>
+<polyline data-id="ys5kFA6ZzwwpupiiG" points="537,231 555,196" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(383 0)"/>
+<polyline data-id="RYdTiRDL6xCKBhafN" points="537,264 554,297" stroke="#ff0000" stroke-width="3" stroke-linecap="round" stroke-linejoin="round" fill="none" transform="translate(383 0)"/>
+<rect data-id="PrqNJQctFTGkQ5vX5" x="944" y="188.79999999999998" width="37.75999999999999" height="226.55999999999997" stroke="black" stroke-width="3" stroke-linejoin="round" fill="none"/>
+</svg> \ No newline at end of file
diff --git a/img/finetune.svg b/img/finetune.svg
new file mode 100644
index 0000000..aeac772
--- /dev/null
+++ b/img/finetune.svg
@@ -0,0 +1,50 @@
+<svg width="172pt" height="188pt" viewBox="0.00 0.00 172.00 188.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 184)">
+<title>G</title>
+<!-- <polygon fill="white" stroke="none" points="-4,4 -4,-184 168,-184 168,4 -4,4"/> -->
+<!-- 1 -->
+<g id="node1" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="blue" cx="27" cy="-90" rx="27" ry="18"/>
+<text text-anchor="middle" x="27" y="-85.8" font-family="Times,serif" font-size="14.00">1</text>
+</g>
+<!-- 3 -->
+<g id="node3" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="green" cx="82" cy="-18" rx="27" ry="18"/>
+<text text-anchor="middle" x="82" y="-13.8" font-family="Times,serif" font-size="14.00">3</text>
+</g>
+<!-- 1&#45;&#45;3 -->
+<g id="edge1" class="edge">
+<title>1--3</title>
+<path fill="none" stroke="red" d="M39.21,-73.46C48.32,-61.86 60.72,-46.08 69.83,-34.49"/>
+</g>
+<!-- 2 -->
+<g id="node2" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="red" cx="109" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="109" y="-157.8" font-family="Times,serif" font-size="14.00">2</text>
+</g>
+<!-- 2&#45;&#45;3 -->
+<g id="edge4" class="edge">
+<title>2--3</title>
+<path fill="none" stroke="red" d="M105.74,-143.87C100.55,-116.58 90.47,-63.52 85.27,-36.19"/>
+</g>
+<!-- 4 -->
+<g id="node4" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="137" cy="-90" rx="27" ry="18"/>
+<text text-anchor="middle" x="137" y="-85.8" font-family="Times,serif" font-size="14.00">4</text>
+</g>
+<!-- 2&#45;&#45;4 -->
+<g id="edge3" class="edge">
+<title>2--4</title>
+<path fill="none" stroke="black" d="M115.64,-144.41C120.1,-133.25 125.96,-118.6 130.41,-107.47"/>
+</g>
+<!-- 4&#45;&#45;3 -->
+<g id="edge2" class="edge">
+<title>4--3</title>
+<path fill="none" stroke="black" d="M124.79,-73.46C115.68,-61.86 103.28,-46.08 94.17,-34.49"/>
+</g>
+</g>
+</svg>
diff --git a/img/test.svg b/img/test.svg
new file mode 100644
index 0000000..2d03d82
--- /dev/null
+++ b/img/test.svg
@@ -0,0 +1,127 @@
+<svg width="316pt" height="260pt" viewBox="0.00 0.00 316.00 260.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 256)">
+<title>G</title>
+<!-- <polygon fill="white" stroke="none" points="-4,4 -4,-256 312,-256 312,4 -4,4"/> -->
+<!-- 1 -->
+<g id="node1" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="blue" cx="27" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="27" y="-157.8" font-family="Times,serif" font-size="14.00">1</text>
+</g>
+<!-- 3 -->
+<g id="node3" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="red" cx="82" cy="-90" rx="27" ry="18"/>
+<text text-anchor="middle" x="82" y="-85.8" font-family="Times,serif" font-size="14.00">3</text>
+</g>
+<!-- 1&#45;&#45;3 -->
+<g id="edge1" class="edge">
+<title>1--3</title>
+<path fill="none" stroke="red" d="M39.21,-145.46C48.32,-133.86 60.72,-118.08 69.83,-106.49"/>
+</g>
+<!-- 2 -->
+<g id="node2" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="red" cx="226" cy="-18" rx="27" ry="18"/>
+<text text-anchor="middle" x="226" y="-13.8" font-family="Times,serif" font-size="14.00">2</text>
+</g>
+<!-- 8 -->
+<g id="node11" class="node">
+<title>8</title>
+<ellipse fill="none" stroke="black" cx="118" cy="-18" rx="27" ry="18"/>
+<text text-anchor="middle" x="118" y="-13.8" font-family="Times,serif" font-size="14.00">8</text>
+</g>
+<!-- 3&#45;&#45;8 -->
+<g id="edge9" class="edge">
+<title>3--8</title>
+<path fill="none" stroke="black" d="M90.35,-72.76C96.17,-61.46 103.89,-46.44 109.7,-35.15"/>
+</g>
+<!-- 5 -->
+<g id="node4" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="green" cx="209" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="209" y="-157.8" font-family="Times,serif" font-size="14.00">5</text>
+</g>
+<!-- 5&#45;&#45;3 -->
+<g id="edge8" class="edge">
+<title>5--3</title>
+<path fill="none" stroke="red" d="M188.6,-149.75C164.9,-136.69 126.03,-115.27 102.35,-102.22"/>
+</g>
+<!-- 11 -->
+<g id="node5" class="node">
+<title>11</title>
+<ellipse fill="none" stroke="red" cx="226" cy="-90" rx="27" ry="18"/>
+<text text-anchor="middle" x="226" y="-85.8" font-family="Times,serif" font-size="14.00">11</text>
+</g>
+<!-- 5&#45;&#45;11 -->
+<g id="edge10" class="edge">
+<title>5--11</title>
+<path fill="none" stroke="red" d="M213.12,-144.05C215.75,-133.21 219.16,-119.18 221.8,-108.28"/>
+</g>
+<!-- 11&#45;&#45;2 -->
+<g id="edge3" class="edge">
+<title>11--2</title>
+<path fill="none" stroke="red" d="M226,-71.7C226,-60.85 226,-46.92 226,-36.1"/>
+</g>
+<!-- 4 -->
+<g id="node6" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="137" cy="-234" rx="27" ry="18"/>
+<text text-anchor="middle" x="137" y="-229.8" font-family="Times,serif" font-size="14.00">4</text>
+</g>
+<!-- 4&#45;&#45;3 -->
+<g id="edge2" class="edge">
+<title>4--3</title>
+<path fill="none" stroke="black" d="M124.25,-217.7C116.44,-207.59 106.85,-193.71 101,-180 90.89,-156.31 85.95,-126.71 83.7,-108.33"/>
+</g>
+<!-- 4&#45;&#45;5 -->
+<g id="edge6" class="edge">
+<title>4--5</title>
+<path fill="none" stroke="black" d="M151.92,-218.5C164.28,-206.48 181.75,-189.49 194.1,-177.49"/>
+</g>
+<!-- 9 -->
+<g id="node7" class="node">
+<title>9</title>
+<ellipse fill="none" stroke="black" cx="137" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="137" y="-157.8" font-family="Times,serif" font-size="14.00">9</text>
+</g>
+<!-- 4&#45;&#45;9 -->
+<g id="edge11" class="edge">
+<title>4--9</title>
+<path fill="none" stroke="black" d="M137,-215.7C137,-204.85 137,-190.92 137,-180.1"/>
+</g>
+<!-- 9&#45;&#45;3 -->
+<g id="edge4" class="edge">
+<title>9--3</title>
+<path fill="none" stroke="black" d="M124.79,-145.46C115.68,-133.86 103.28,-118.08 94.17,-106.49"/>
+</g>
+<!-- 6 -->
+<g id="node8" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="281" cy="-234" rx="27" ry="18"/>
+<text text-anchor="middle" x="281" y="-229.8" font-family="Times,serif" font-size="14.00">6</text>
+</g>
+<!-- 7 -->
+<g id="node9" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="281" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="281" y="-157.8" font-family="Times,serif" font-size="14.00">7</text>
+</g>
+<!-- 6&#45;&#45;7 -->
+<g id="edge5" class="edge">
+<title>6--7</title>
+<path fill="none" stroke="black" d="M281,-215.7C281,-204.85 281,-190.92 281,-180.1"/>
+</g>
+<!-- 10 -->
+<g id="node10" class="node">
+<title>10</title>
+<ellipse fill="none" stroke="black" cx="154" cy="-90" rx="27" ry="18"/>
+<text text-anchor="middle" x="154" y="-85.8" font-family="Times,serif" font-size="14.00">10</text>
+</g>
+<!-- 10&#45;&#45;8 -->
+<g id="edge7" class="edge">
+<title>10--8</title>
+<path fill="none" stroke="black" d="M145.65,-72.76C139.83,-61.46 132.11,-46.44 126.3,-35.15"/>
+</g>
+</g>
+</svg>
diff --git a/img/train.svg b/img/train.svg
new file mode 100644
index 0000000..ef63aac
--- /dev/null
+++ b/img/train.svg
@@ -0,0 +1,127 @@
+<svg width="316pt" height="260pt" viewBox="0.00 0.00 316.00 260.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 256)">
+<title>G</title>
+<!-- <polygon fill="white" stroke="none" points="-4,4 -4,-256 312,-256 312,4 -4,4"/> -->
+<!-- 1 -->
+<g id="node1" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="blue" cx="27" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="27" y="-157.8" font-family="Times,serif" font-size="14.00">1</text>
+</g>
+<!-- 3 -->
+<g id="node3" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="red" cx="82" cy="-90" rx="27" ry="18"/>
+<text text-anchor="middle" x="82" y="-85.8" font-family="Times,serif" font-size="14.00">3</text>
+</g>
+<!-- 1&#45;&#45;3 -->
+<g id="edge1" class="edge">
+<title>1--3</title>
+<path fill="none" stroke="red" d="M39.21,-145.46C48.32,-133.86 60.72,-118.08 69.83,-106.49"/>
+</g>
+<!-- 2 -->
+<g id="node2" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="green" cx="226" cy="-18" rx="27" ry="18"/>
+<text text-anchor="middle" x="226" y="-13.8" font-family="Times,serif" font-size="14.00">2</text>
+</g>
+<!-- 8 -->
+<g id="node11" class="node">
+<title>8</title>
+<ellipse fill="none" stroke="black" cx="118" cy="-18" rx="27" ry="18"/>
+<text text-anchor="middle" x="118" y="-13.8" font-family="Times,serif" font-size="14.00">8</text>
+</g>
+<!-- 3&#45;&#45;8 -->
+<g id="edge9" class="edge">
+<title>3--8</title>
+<path fill="none" stroke="black" d="M90.35,-72.76C96.17,-61.46 103.89,-46.44 109.7,-35.15"/>
+</g>
+<!-- 5 -->
+<g id="node4" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="red" cx="209" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="209" y="-157.8" font-family="Times,serif" font-size="14.00">5</text>
+</g>
+<!-- 5&#45;&#45;3 -->
+<g id="edge8" class="edge">
+<title>5--3</title>
+<path fill="none" stroke="red" d="M188.6,-149.75C164.9,-136.69 126.03,-115.27 102.35,-102.22"/>
+</g>
+<!-- 11 -->
+<g id="node5" class="node">
+<title>11</title>
+<ellipse fill="none" stroke="red" cx="226" cy="-90" rx="27" ry="18"/>
+<text text-anchor="middle" x="226" y="-85.8" font-family="Times,serif" font-size="14.00">11</text>
+</g>
+<!-- 5&#45;&#45;11 -->
+<g id="edge10" class="edge">
+<title>5--11</title>
+<path fill="none" stroke="red" d="M213.12,-144.05C215.75,-133.21 219.16,-119.18 221.8,-108.28"/>
+</g>
+<!-- 11&#45;&#45;2 -->
+<g id="edge3" class="edge">
+<title>11--2</title>
+<path fill="none" stroke="red" d="M226,-71.7C226,-60.85 226,-46.92 226,-36.1"/>
+</g>
+<!-- 4 -->
+<g id="node6" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="137" cy="-234" rx="27" ry="18"/>
+<text text-anchor="middle" x="137" y="-229.8" font-family="Times,serif" font-size="14.00">4</text>
+</g>
+<!-- 4&#45;&#45;3 -->
+<g id="edge2" class="edge">
+<title>4--3</title>
+<path fill="none" stroke="black" d="M124.25,-217.7C116.44,-207.59 106.85,-193.71 101,-180 90.89,-156.31 85.95,-126.71 83.7,-108.33"/>
+</g>
+<!-- 4&#45;&#45;5 -->
+<g id="edge6" class="edge">
+<title>4--5</title>
+<path fill="none" stroke="black" d="M151.92,-218.5C164.28,-206.48 181.75,-189.49 194.1,-177.49"/>
+</g>
+<!-- 9 -->
+<g id="node7" class="node">
+<title>9</title>
+<ellipse fill="none" stroke="black" cx="137" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="137" y="-157.8" font-family="Times,serif" font-size="14.00">9</text>
+</g>
+<!-- 4&#45;&#45;9 -->
+<g id="edge11" class="edge">
+<title>4--9</title>
+<path fill="none" stroke="black" d="M137,-215.7C137,-204.85 137,-190.92 137,-180.1"/>
+</g>
+<!-- 9&#45;&#45;3 -->
+<g id="edge4" class="edge">
+<title>9--3</title>
+<path fill="none" stroke="black" d="M124.79,-145.46C115.68,-133.86 103.28,-118.08 94.17,-106.49"/>
+</g>
+<!-- 6 -->
+<g id="node8" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="281" cy="-234" rx="27" ry="18"/>
+<text text-anchor="middle" x="281" y="-229.8" font-family="Times,serif" font-size="14.00">6</text>
+</g>
+<!-- 7 -->
+<g id="node9" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="281" cy="-162" rx="27" ry="18"/>
+<text text-anchor="middle" x="281" y="-157.8" font-family="Times,serif" font-size="14.00">7</text>
+</g>
+<!-- 6&#45;&#45;7 -->
+<g id="edge5" class="edge">
+<title>6--7</title>
+<path fill="none" stroke="black" d="M281,-215.7C281,-204.85 281,-190.92 281,-180.1"/>
+</g>
+<!-- 10 -->
+<g id="node10" class="node">
+<title>10</title>
+<ellipse fill="none" stroke="black" cx="154" cy="-90" rx="27" ry="18"/>
+<text text-anchor="middle" x="154" y="-85.8" font-family="Times,serif" font-size="14.00">10</text>
+</g>
+<!-- 10&#45;&#45;8 -->
+<g id="edge7" class="edge">
+<title>10--8</title>
+<path fill="none" stroke="black" d="M145.65,-72.76C139.83,-61.46 132.11,-46.44 126.3,-35.15"/>
+</g>
+</g>
+</svg>
diff --git a/index.html b/index.html
index efec318..fd5f74f 100644
--- a/index.html
+++ b/index.html
@@ -4,7 +4,7 @@
<meta charset="utf-8" />
<meta name="generator" content="pandoc" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" />
- <title>6.7960 Project: Investigating Off-Distribution Generalization of Transformers</title>
+ <title>Discovering Graph Algorithms Using Transformers</title>
<style>
html {
color: #1a1a1a;
@@ -271,74 +271,51 @@
</head>
<body>
<header id="title-block-header">
-<h1 class="title">6.7960 Project: Investigating Off-Distribution
-Generalization of Transformers</h1>
+<h1 class="title">Discovering Graph Algorithms Using Transformers</h1>
</header>
<!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 -->
<div style="text-align:center">
<p>Anthony Wang, Alek Westover, Kevin Zhao</p>
<p>{xy,alekw,kevinmz}@mit.edu</p>
</div>
-<h2 id="goals">Goals</h2>
-<p>Recently, LLMs have been developing very fast, and with that comes
-the concern of aligning the models to output true and productive
-statements. One common approach for ensuring this is to have a human in
-the loop rewarding the model for true outputs (e.g. RLHF), but one
-drawback to this problem is that humans can be poor judges of
-truthfulness. As LLMs become more capable, there might not even exist
-experts that are good judges of whether the model’s outputs, such as
-difficult mathematical proofs, are truthful. So, we’d like to propose a
-potential solution to this issue via <strong>off-distribution
-generalization</strong> - applying human-like intuition to solve
-problems not in the dataset. Paul Christiano <a
-href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">proposed
-an experiment</a> about shortest paths in a graph; our project is
-essentially to implement Christiano’s proposed experiment. To the best
-of our knowledge, although there has been research in applying machine
-learning for different variations of graph searches <span
-class="citation" data-cites="10.5555/3666122.3666260">(<a
-href="#ref-10.5555/3666122.3666260" role="doc-biblioref">Zang et al.
-2024</a>)</span>, no one has done our exact experiment yet.</p>
-<p>It is generally desirable for LLMs to output true statements. A
-current approach for ensuring this is to have a human in the loop
-rewarding the model for true outputs (e.g. RLHF); however, humans can be
-poor judges of truthfulness. We enjoy many cognitive biases and might
-employ superficial heuristics when judging truthfulness. A further
-challenge is that as LLMs develop further, there might not even exist
-experts that can correctly judge the accuracy and truthfulness of
-sophisticated outputs such as difficult mathematical proofs.</p>
-<p>One approach to solving this problem is to reward an LLM for truthful
-behavior on simple inputs, and then hoping that the LLM generalizes its
-truthful behavior for more complex inputs where humans cannot provide
-helpful labels. Deep learning models often perform remarkable feats of
-off-distribution generalization – for instance, a model trained to
-transform hand drawn cats into images of cats might be able to handle a
-“cat” with three eyes in an intuitive way. We might hope that
-generalizing truthfully is simple, thus promoted by “Occam’s Razor”, and
-aim to investigate that with this project.</p>
-<p>COMMENT FROM KEVIN – synthesize from intorduction</p>
-<h3 id="task">Task</h3>
-<p>We will use a synthetic task to test our hypothesis that models will
-generalize truthfully off-distribution. The synthetic task is computing
-the distance between various vertices in an input graph. Our experiment
-will have three parts:</p>
-<ol type="1">
-<li>Pre-train a transformer to predict the distance between two fixed
-vertices <span class="math inline">s,t</span> on graphs with <span
-class="math inline">n\in [8, 32)</span> vertices.</li>
-<li>Fine-tune a transformer to predict the distances between <span
-class="math inline">s,t&#39;</span> for any <span
-class="math inline">t&#39;</span> which is on the shortest path from
-<span class="math inline">s</span> to <span
-class="math inline">t</span>, but only do fine-tuning on graphs with
-<span class="math inline">n\in [8,16)</span> vertices.</li>
-<li>Test whether the transformer can accurately predict the distances
-between <span class="math inline">s,t&#39;</span> for any <span
-class="math inline">t&#39;</span> on the shortest path from <span
-class="math inline">s</span> to <span class="math inline">t</span> for
-graphs with <span class="math inline">n\in [16,32)</span> vertices.</li>
-</ol>
-<h3 id="related-work">Related Work</h3>
+<h2 id="motivation">Motivation</h2>
+<p>Transformers–the architecture that powers LLMs–can do incredible
+feats: trained on hundreds of gigabytes of raw text, they can learn to
+hold natural conversations, reason about the physical world, and write
+code. Skeptics argue that LLMs are simply memorizing their datasets
+without gaining any deeper understanding. For instance, GPT’s o1 model,
+achieving 90th percentile on Codeforces, struggles with simple but
+bizarre algorithms problems such as “find the subarray of a 2D array
+with the minimum average”. In this project, we hope to explore
+<strong>when off-distribution generalization happens in a
+transformer</strong>. Paul Christiano proposed an experiment <a
+href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">here</a>
+about shortest paths in a graph to investigate this, so we decided to
+become the first to implement his experiment and put transformers’
+generalization abilities to the test.</p>
+<p>LLMs are notorious for making up complete nonsense, so we also hope
+that our project can shed light on when truthfulness generalizes. It’s
+generally desirable for LLMs to output true statements. One current
+approach for ensuring this is to have a human in the loop rewarding the
+model for true outputs (e.g. RLHF). However, humans can be poor judges
+of truthfulness and have many cognitive biases and superficial
+heuristics. A further challenge is that as LLMs become more capable,
+there might not even exist experts that are good judges of whether the
+models outputs, such as difficult mathematical proofs, are truthful. For
+instance, most Task Rabbit workers would probably be hard pressed to
+evaluate whether a difficult mathematical proof produced by an LLM is
+true. The entire mathematical community has been known on occasion to <a
+href="https://en.wikipedia.org/wiki/Grunwald%E2%80%93Wang_theorem">believe
+false statements for many years</a>.</p>
+<p>One possible solution is to reward an LLM for truthful behavior on
+simple inputs, and then hope that the LLM generalizes its truthful
+behavior for more complex inputs where humans cannot provide helpful
+labels. Deep learning models can be remarkably good at off-distribution
+generalization–for instance, a model trained to transform hand drawn
+cats into images of cats might be able to handle a “cat” with three eyes
+in an intuitive way. We might hope that generalizing truthfully is
+simple, thus promoted by “Occam’s Razor”.</p>
+<h2 id="related-work">Related Work</h2>
<p>COMMENT FROM ALEK – please remove all mentions of graph neural
networks – that is BS: there is no actual reason why you’d ever use a
Neural network to solve shortest paths, the point of choosing a
@@ -367,170 +344,113 @@ href="#ref-10.1109/TPAMI.2023.3256421" role="doc-biblioref">Tutsoy
2023</a>)</span>. We understand from his paper how GNN optimization may
also be useful in researching novel diseases.</p></li>
</ul>
-<h2 id="methods">Methods</h2>
+<h2 id="task">Task</h2>
+<p>Our synthetic task is simple: compute the distance between various
+vertices in an input graph. To test off-distribution generalization, our
+experiment has three steps.</p>
+<ol type="1">
+<li><strong>Pre-train</strong> a transformer to predict the distance
+between vertices <span class="math inline">1</span> and <span
+class="math inline">2</span> in graphs with <span class="math inline">n
+\in [3,15]</span> vertices.</li>
+</ol>
+<div style="text-align:center">
+<p><img src="img/train.svg" /></p>
+</div>
+<ol start="2" type="1">
+<li><strong>Fine-tune</strong> a transformer to predict the distances
+between vertex <span class="math inline">1</span> to <span
+class="math inline">t</span> for any <span class="math inline">t</span>
+on the shortest path from <span class="math inline">1</span> to <span
+class="math inline">2</span>, but only do fine-tuning on graphs with
+<span class="math inline">n \in [3,7]</span> vertices.</li>
+</ol>
+<div style="text-align:center">
+<p><img src="img/finetune.svg" /></p>
+</div>
+<ol start="3" type="1">
+<li><strong>Test</strong> whether the transformer can accurately predict
+the distances between <span class="math inline">1</span> to <span
+class="math inline">t</span> for any <span class="math inline">t \leq
+7</span> on the shortest path from <span class="math inline">1</span> to
+<span class="math inline">2</span> for graphs with <span
+class="math inline">n \in [3,15]</span> vertices.</li>
+</ol>
+<div style="text-align:center">
+<p><img src="img/test.svg" /></p>
+</div>
<h3 id="algorithm-for-shortest-paths">Algorithm for Shortest Paths</h3>
-<p>The standard algorithm to find the shortest path in a graph between a
-source numbered as <span class="math inline">u</span> and sink numbered
-as <span class="math inline">v</span> is <strong>breadth-first search
-(BFS)</strong>. The BFS algorithm maintains a mapping of visited
-vertices to their distances with respect to <span
-class="math inline">u</span>, and each run of the algorithm goes through
-all the vertices newly visited in the previous run, and for each vertex,
-visits any of its unvisited neighbors. The algorithm terminates once
-either <span class="math inline">v</span> is visited or the set of newly
-visited vertices in a single run is empty.</p>
-<p>We will use this algorithm to verify the accuracy of our machine
-learning approach. Given <span class="math inline">V</span> vertices and
-<span class="math inline">E</span> edges, the runtime of this algorithm
-is thus <span class="math inline">O(V + E)</span>; however, a machine
-learning approach may do better in time through parallelism, although at
-the expense of using much more memory.</p>
+<p>The standard algorithm to find the shortest path in a graph between
+vertices <span class="math inline">u</span> and <span
+class="math inline">v</span> is <strong>breadth-first search
+(BFS)</strong>, taught in every intro algorithms class. Initially, BFS
+starts at <span class="math inline">u</span>, and at each phase,
+explores a farther layer of vertices from <span
+class="math inline">u</span>. During a phase, BFS goes through every
+vertex in the current layer and adds any of their unvisited neighbors to
+the next layer. The algorithm terminates once we reach <span
+class="math inline">v</span> or if the next layer is empty. For a graph
+with <span class="math inline">V</span> vertices and <span
+class="math inline">E</span> edges, the runtime of BFS is <span
+class="math inline">O(V + E)</span>. BFS gives us an easy and fast way
+to find the ground truth answer for any graph, so that we can verify the
+accuracy of our machine learning approach.</p>
+<p>We hope that our model can learn BFS or some other simple,
+generalizable algorithm for shortest paths, because the model can’t just
+pull some magic number out of a hat but intuitively needs to count or
+list the vertices on the shortest path from <span
+class="math inline">1</span> to <span class="math inline">2</span>. In
+fact, we will show how to hand-craft a set of weights to implement BFS
+in a transformer, so it’s indeed theoretically possible for a
+transformer to achieve 100% accuracy.</p>
<h3 id="data">Data</h3>
-<p>We will represent an <span class="math inline">n</span> vertex, <span
+<p>We’ll represent an <span class="math inline">n</span> vertex, <span
class="math inline">m</span> edge unweighted, undirected graph as
sequence of the endpoints of the <span class="math inline">m</span>
edges, so <span
class="math inline">[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]</span> represents a
-graph with the edges <span class="math inline">\{(a_i,b_i)\}</span> for
-<span class="math inline">1 \leq i \leq m</span>. We will pad all
-sequences to be the same length using the padding token 0.</p>
-<p>The full input to our model will additionally add the target vertex
-after the padding tokens. The model is tasked with predicting the length
-of the shortest path between vertex 1 and the target vertex <span
-class="math inline">t</span>. If no such path exists, we define the
-length to be <span class="math inline">n+1</span> which represents
-infinity. For example, an input-output pair for our model could look
-like <span class="math inline">[1, 3, 3, 2, 0, 0, 0, 0, 2]</span> and
-<span class="math inline">2</span> respectively.</p>
-<p>We have three separate datasets.</p>
-<ul>
+graph with the edges <span class="math inline">\{(a_i,b_i)\}</span>
+where <span class="math inline">a_i &lt; b_i</span> for <span
+class="math inline">1 \leq i \leq m</span>. All sequences are padded to
+the same length using the padding token <span
+class="math inline">0</span>.</p>
+<p>The full input to our model additionally includes the target vertex
+<span class="math inline">t</span> after the padding tokens. The label
+to an input is the length of the shortest path from <span
+class="math inline">1</span> to <span class="math inline">t</span>. If
+no such path exists, we define the length to be <span
+class="math inline">n+1</span> which represents infinity. For example,
+the input <span class="math inline">[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0,
+3]</span> has the label <span class="math inline">2</span>.</p>
+<div style="text-align:center">
+<p><img src="img/finetune.svg" /></p>
+</div>
+<p>We have three datasets for each step.</p>
+<ol type="1">
<li><strong>Pre-train data</strong>: For each <span
-class="math inline">n \in [8,32)</span>, we will generate several graphs
-on <span class="math inline">n</span> vertices. We generate these graphs
-by inserting <span class="math inline">2n</span> random edges into the
-graph. We always set the target vertex to be <span
-class="math inline">2</span> here.</li>
+class="math inline">n \in [3,15]</span>, we generated an equal number of
+graphs on <span class="math inline">n</span> vertices, with <span
+class="math inline">t = 2</span>. Each graph was created by choosing
+<span class="math inline">n</span> random edges.</li>
<li><strong>Fine-tune data</strong>: For each <span
-class="math inline">n \in [8,16)</span>, we will generate several graphs
-on <span class="math inline">n</span> vertices. We generate these graphs
-by inserting <span class="math inline">2n</span> random edges into the
-graph. We select the target vertex to be a random vertex on the shortest
-path from <span class="math inline">1</span> to <span
-class="math inline">2</span>.</li>
-<li><strong>Generalization testing data</strong>: The same as the
-fine-tune data, except we sample <span class="math inline">n \in
-[16,32)</span> instead.</li>
-</ul>
+class="math inline">n \in [3,7]</span>, we generated an equal number of
+graphs on <span class="math inline">n</span> vertices each with a random
+<span class="math inline">t</span> on the shortest path from <span
+class="math inline">1</span> to <span class="math inline">2</span>.
+Again, each graph was created by choosing <span
+class="math inline">n</span> random edges.</li>
+<li><strong>Generalization test data</strong>: The same as the fine-tune
+data, except we sample <span class="math inline">n \in [3,15]</span> and
+<span class="math inline">t \leq 7</span>.</li>
+</ol>
<p>We wrote some Python code to generate the data during the training
-loop, but Python is slow and the data generation wasted a lot of time
-during training. To get around this, we pre-generated the data before
-training and made our Python code multithreaded to speed it up.</p>
-<h3 id="architecture">Architecture</h3>
-<p>TODO: honestly not much to say here since it’s a pretty typical
-arch</p>
-<p>We plan to use a standard transformer architecture. We will ensure
-that the number of layers in our transformer is at least the diameter of
-the graph. By doing this, we ensure that there is an extremely simple
-circuit — namely BFS — that the transformer could in theory learn to
-perform the task. Note that if the transformer actually learns a simple
-circuit to perform this task, then it seems more likely to generalize
-well. This is also our intuition for why it should be possible to fine
-tune on a small amount of data for finding shortest paths to other
-vertices besides <span class="math inline">2</span> – it seems like the
-model should be computing these other distances as intermediate values
-in its computation to find the distance to vertex <span
-class="math inline">2</span>.</p>
-<h3 id="embeddings">Embeddings</h3>
-<p>Since the order of the edges in the input does not matter, we did not
-use positional encodings. Each edge <span
-class="math inline">(u,v)</span> where <span class="math inline">u &lt;
-v</span> is embedded to a dimension of <span
-class="math inline">d</span> where the first <span
-class="math inline">\frac{d}{2}</span> elements are the learned
-embedding of <span class="math inline">u</span> and the last <span
-class="math inline">\frac{d}{2}</span> elements are the learned
-embedding of <span class="math inline">v</span>. For the target vertex
-<span class="math inline">t</span>, we also embedded to dimension <span
-class="math inline">d</span>, where the first <span
-class="math inline">\frac{d}{2}</span> elements are the learned
-embedding of <span class="math inline">t</span> and the last <span
-class="math inline">\frac{d}{2}</span> are a learned embedding of a
-special token.</p>
-<h2 id="training">Training</h2>
-<p>For our model, we used a model dimension of 64, four layers, and two
-heads per layer, for a total of 200545 parameters in bfloat16 which
-corresponds to around 3.2e6 bits. The number of possible graphs on 15
-vertices generated using our procedure is approximately</p>
-<p><span class="math display">\frac{\binom{15}{2}^{15}}{15!} =
-1.59\cdot10^{18}.</span></p>
-<p>This is because there are <span
-class="math inline">\binom{15}{2}</span> choices for each of the 15
-edges and we don’t care about the order of the edges. This is only an
-approximation because some edges might be duplicated. Each graph has an
-answer between 1 and 15 which requires around 4 bits, so memorizing all
-the answers requires <span class="math inline">4\cdot1.59\cdot10^{18} =
-6.36\cdot10^{18}</span>, which is <span
-class="math inline">2\cdot10^{12}</span> times larger than our model
-size.</p>
-<p>We used MSE loss, the Adam optimizer, a learning rate of 8e-4, and a
-batch size of 131072 for 8000 unique randomly generated batches. Our
-final MSE loss was approximately 0.3555.</p>
-<p><img src="training-loss.png" /></p>
-<p><img src="training-2d-histogram.png" /></p>
-<p>One pattern we consistently noticed during training is that the model
-often gets stuck and plateaus for many epochs before rapidly decreasing.
-For instance, this happened between epochs 100 and 300 in the graph
-above:</p>
-<p><img src="grokking.png" /></p>
-<p>“grokking” hypothesis: it’s memorizing all length 2 paths?</p>
-<p>TODO: training curves for 1, 2, 3 length paths</p>
-<h3
-id="potential-mathematical-approaches-to-shortest-paths-delete-this">Potential
-Mathematical Approaches to Shortest Paths? Delete this?</h3>
-<p>Another way one can think of the shortest path of a graph is using a
-<em>matrix</em> to record which vertices are connected. Given vertices
-numbered <span class="math inline">1</span> to <span
-class="math inline">V</span>, we denote the <strong>adjacency
-matrix</strong> <span class="math inline">\textbf{M}</span> of
-dimensions <span class="math inline">V \times V</span> as the matrix
-with element <span class="math inline">\textbf{M}_{i, j} = 1</span> if
-vertices <span class="math inline">i</span> and <span
-class="math inline">j</span> are connected by an edge and <span
-class="math inline">\textbf{M}_{i, j} = 0</span> if they are not. Now,
-we note that (1) For all <span class="math inline">k</span>, <span
-class="math inline">(\textbf{M}+I)^k_{i, j} = 0</span> if and only if
-there exists no path from the vertex numbered <span
-class="math inline">i</span> to the vertex numbered <span
-class="math inline">j</span> that is distance <span
-class="math inline">k</span> or less due to Markov matrix processes. As
-a result, if the distance between vertices numbered <span
-class="math inline">i</span> and <span class="math inline">j</span> is
-<span class="math inline">d</span>, then <span
-class="math inline">\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) =
-1</span> if <span class="math inline">k \ge d</span> and <span
-class="math inline">\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) =
-0</span> if <span class="math inline">k &lt; d</span>.</p>
-<p>With this information, because the distance between any two vertices
-is at most <span class="math inline">V-1</span> in a graph with <span
-class="math inline">V</span> vertices, we note that the
-<em>distance</em> matrix turns out to be simply <span
-class="math display">\textbf{D} = \textbf{1}_{V \times V} \cdot V -
-\Sigma_{i=0}^{V-1}\text{min}\left((\textbf{M}+I)^k_{i, j},
-1\right).</span> The runtime to compute this is <span
-class="math inline">O(V)</span>, although it will take more space to
-compute all powers of <span class="math inline">\textbf{M}</span>.</p>
-<h2 id="fine-tuning-results">Fine tuning results</h2>
-<p>After receiving our initial results, we fine-tuned with a learning
-rate of 1e-5, also with MSE and the same batch size. Our final results
-are shown in the images below.</p>
-<p><img src="fine-tuning-loss.png" /></p>
-<p><img src="fine-tuning-2d-histogram.png" /></p>
-<p><img src="test-2d-histogram.png" /></p>
-<p>Memorization? Do some math here to compute how many bits required to
-memorize 1, 2, 3</p>
+loop, but Python is excruciatingly slow and data generation wasted a lot
+of training time. Our solution was to pre-generate the data before
+training using a multithreaded version of our Python code.</p>
<h2
id="complicated-explicit-transformer-formula-for-shortest-paths">Complicated
explicit transformer formula for shortest paths</h2>
+<p>TODO: Kevin or Alek</p>
<div class="sourceCode" id="cb1"><pre class="sourceCode py"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>NVTXS <span class="op">=</span> <span class="dv">16</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>MAXDIST <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
@@ -629,9 +549,6 @@ explicit transformer formula for shortest paths</h2>
<span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a> canreach <span class="op">=</span> src[:, <span class="dv">0</span>, START_OUT:START_OUT <span class="op">+</span> NVTXS]</span>
<span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a> final_output <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> torch.<span class="bu">sum</span>(<span class="dv">1</span> <span class="op">-</span> canreach, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> final_output</span></code></pre></div>
-<p>It looked like the fine tuning results weren’t as meaningful because
-TODO: ALEK</p>
-<h2 id="customizing-a-transformer">Customizing a Transformer</h2>
<p>After much deliberation, we decided the next step for us was to
customize a transformer, writing it ourselves. We observed that we
wished for the transformer to do similar work as a BFS. As a result, we
@@ -686,8 +603,95 @@ class="math inline">i</span>th entry is nonzero.</p>
<p>This iterates through all the vertices, and at the end, we may see
what run we are on to update <span
class="math inline">V_{\text{OUT}}</span>.</p>
+<h2 id="perturbing-the-weights">Perturbing the Weights</h2>
+<p>SLT folks like to think about geometry of loss landscape CITE So we
+did an experiment where we mess with the weights.</p>
+<p>Findings: XXX</p>
+<h2 id="our-model">Our Model</h2>
+<h3 id="architecture">Architecture</h3>
+<p>We used a standard transformer architecture. To ensure that it can in
+theory learn BFS, we ensured that the number of layers in our
+transformer exceeds the diameter of the input graphs.</p>
+<p>Since the order of the edges in the input doesn’t matter, we did not
+use positional encodings. Each edge <span
+class="math inline">(a,b)</span> is embedded to dimension <span
+class="math inline">d</span> where the first <span
+class="math inline">\frac{d}{2}</span> elements are the learned
+embedding of <span class="math inline">a</span> and the last <span
+class="math inline">\frac{d}{2}</span> elements are the learned
+embedding of <span class="math inline">b</span>. For the target vertex
+<span class="math inline">t</span>, we pair it with the special token
+<span class="math inline">TARGET</span> and embed <span
+class="math inline">(t,TARGET)</span> in the same way.</p>
+<!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm -->
+<p><img src="img/embeddings.svg" /></p>
+<h3 id="training">Training</h3>
+<p>To match the BFS transformer as closely as possible, we used a model
+dimension of <span class="math inline">64</span>, <span
+class="math inline">11</span> layers, and <span
+class="math inline">2</span> heads per layer, for a total of 550433
+parameters. In 32-bit float precision, that corresponds to around <span
+class="math inline">1.76\cdot10^6</span> bits. The number of possible
+graphs on 15 vertices generated using our procedure is approximately</p>
+<p><span class="math display">\frac{\binom{15}{2}^{15}}{15!} =
+1.59\cdot10^{18}.</span></p>
+<p>This is because there are <span
+class="math inline">\binom{15}{2}</span> choices for each of the 15
+edges and we don’t care about the order of the edges. This is only an
+approximation because some edges might be duplicated. Each graph has an
+answer between 1 and 15 which requires around 4 bits, so memorizing all
+the answers requires <span class="math inline">4\cdot1.59\cdot10^{18} =
+6.36\cdot10^{18}</span> bits, which is <span
+class="math inline">3.61\cdot10^{12}</span> times larger than our model
+size.</p>
+<p>To train the model, we used MSE loss, the Adam optimizer, a learning
+rate of <span class="math inline">3\cdot10^{-4}</span>, and a batch size
+of <span class="math inline">2^{15}</span> for one billion randomly
+generated graphs. A training run takes roughly eight hours to run on a
+Radeon 7900 XTX graphics card. Our final MSE loss was <span
+class="math inline">0.000555</span>.</p>
+<p>TODO: use https://mpld3.github.io/index.html to make interactive
+plots</p>
+<p><img src="training-loss.png" /></p>
+<p><img src="training-2d-histogram.png" /></p>
+<p>One pattern we consistently noticed during training is that the model
+often gets stuck and plateaus for many epochs before rapidly decreasing.
+For instance, this happened between epochs 100 and 300 in the graph
+above:</p>
+<p><img src="grokking.png" /></p>
+<p>“grokking” hypothesis: it’s memorizing all length 2 paths?</p>
+<p>TODO: cite Neel Nanda grokking modular addition</p>
+<p>TODO: CRAZY!!! training curves for 1, 2, 3 length paths</p>
+<p>One pitfall we encountered during training is that we initially used
+bfloat16 to save VRAM, but our huge batch size caused loss-of-precision
+problems and made training very difficult. It took us two weeks to debug
+this until we found that switching to float32 improved training
+significantly.</p>
+<h2 id="fine-tuning-results">Fine tuning results</h2>
+<p>After receiving our initial results, we fine-tuned with a learning
+rate of 1e-5, also with MSE and the same batch size. Our final results
+are shown in the images below.</p>
+<p><img src="fine-tuning-loss.png" /></p>
+<p><img src="fine-tuning-2d-histogram.png" /></p>
+<p><img src="test-2d-histogram.png" /></p>
+<p>TODO: get new graphs</p>
+<p>It’s pretty good!!!</p>
+<p>Can only generalize to target vertices from 2 to 7 since 8 through 15
+didn’t appear in the fine-tune data</p>
+<p>but this still means it</p>
<h2 id="conclusion">Conclusion</h2>
+<p>however, a machine learning approach may do better in time through
+parallelism, although at the expense of using much more memory.
+<strong>TODO: ALEK: this is BS. If you want a parallel algorithm for
+BFS, here’s one
+https://en.wikipedia.org/wiki/Parallel_single-source_shortest_path_algorithm</strong></p>
<p>just do bfs lol</p>
+<p><strong>Future Work</strong> There are a couple of other things that
+we could try to learn shortest paths better and maybe see more
+generalization. - Chain of thought - Train model to output a path, not
+just the distance. Give it partial points for outputting anything that
+is a legitimate path (i.e., consists of real edges) and more points for
+getting the distance correct.</p>
<h2 class="unnumbered" id="references">References</h2>
<div id="refs" class="references csl-bib-body hanging-indent"
data-entry-spacing="0" role="list">
@@ -707,13 +711,6 @@ Exact Epidemiological Modeling of Pandemic Diseases.”</span> <em>IEEE
Trans. Pattern Anal. Mach. Intell.</em> 45 (8): 9836–45. <a
href="https://doi.org/10.1109/TPAMI.2023.3256421">https://doi.org/10.1109/TPAMI.2023.3256421</a>.
</div>
-<div id="ref-10.5555/3666122.3666260" class="csl-entry" role="listitem">
-Zang, Xiao, Miao Yin, Jinqi Xiao, Saman Zonouz, and Bo Yuan. 2024.
-<span>“GraphMP: Graph Neural Network-Based Motion Planning with
-Efficient Graph Search.”</span> In <em>Proceedings of the 37th
-International Conference on Neural Information Processing Systems</em>.
-NIPS ’23. Red Hook, NY, USA: Curran Associates Inc.
-</div>
</div>
</body>
</html>
diff --git a/insane-shortest-paths.ipynb b/insane-shortest-paths.ipynb
index e74974b..a3e58fb 100644
--- a/insane-shortest-paths.ipynb
+++ b/insane-shortest-paths.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 1,
"execution_state": "idle",
"id": "86ce5f44-94f6-43b0-a0d1-091b8134ffb6",
"metadata": {},
@@ -22,8 +22,8 @@
"from collections import deque\n",
"\n",
"# Set manual seeds for reproducibility\n",
- "torch.manual_seed(33)\n",
- "random.seed(33)\n",
+ "# torch.manual_seed(33)\n",
+ "# random.seed(33)\n",
"\n",
"# Configuration\n",
"NVTXS = 16\n",
@@ -122,7 +122,6 @@
" Q = nn.Parameter(torch.zeros((2, HIDDENDIM), device=device))\n",
" Q[0, START_REACH - 1 + head] = SUPABIG\n",
" Q[1, NOTANS_FLAG_IDX] = 1\n",
- "btrfs filesystem resize max\n",
" K = nn.Parameter(torch.zeros((2, HIDDENDIM), device=device))\n",
" K[0, head] = 1\n",
" K[1, ANS_FLAG_IDX] = BIG\n",
@@ -134,7 +133,6 @@
" self.mostKs.append(K)\n",
" self.mostQs.append(Q)\n",
" self.mostVs.append(V)\n",
- "\n",
" self.weirdKs = nn.ParameterList()\n",
" self.weirdQs = nn.ParameterList()\n",
" self.weirdVs = nn.ParameterList()\n",
@@ -204,35 +202,440 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"execution_state": "idle",
"id": "a9dd76f4-96f2-47b5-9bb9-a32a1b478dd4",
"metadata": {},
"outputs": [
{
- "name": "stdout",
+ "name": "stderr",
"output_type": "stream",
"text": [
- "Epoch [0/10000], Loss: 8.3387\n",
- "Epoch [10/10000], Loss: 7.6416\n",
- "Epoch [20/10000], Loss: 11.2689\n",
- "Epoch [30/10000], Loss: 7.0312\n",
- "Epoch [40/10000], Loss: 8.7287\n",
- "Epoch [50/10000], Loss: 7.7182\n"
+ "/tmp/ipykernel_15454/381745885.py:148: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)\n",
+ " ksrc = torch.matmul(src, K.unsqueeze(0).transpose(-2, -1))\n"
]
},
{
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[6], line 11\u001b[0m\n\u001b[1;32m 9\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_fn(outputs, labels)\n\u001b[1;32m 10\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 11\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 13\u001b[0m train_err\u001b[38;5;241m.\u001b[39mappend(loss\u001b[38;5;241m.\u001b[39mitem())\n",
- "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/_tensor.py:581\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 572\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 573\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 574\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 579\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 580\u001b[0m )\n\u001b[0;32m--> 581\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 583\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/autograd/__init__.py:347\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 342\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 344\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 345\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 346\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 347\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 349\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 350\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 351\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/autograd/graph.py:825\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 823\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 824\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 825\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 826\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 827\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 828\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 829\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch [0/10000], Loss: 0.0025\n",
+ "Epoch [10/10000], Loss: 6.4609\n",
+ "Epoch [20/10000], Loss: 11.0729\n",
+ "Epoch [30/10000], Loss: 10.3862\n",
+ "Epoch [40/10000], Loss: 8.3659\n",
+ "Epoch [50/10000], Loss: 8.4364\n",
+ "Epoch [60/10000], Loss: 7.0110\n",
+ "Epoch [70/10000], Loss: 6.2279\n",
+ "Epoch [80/10000], Loss: 14.5876\n",
+ "Epoch [90/10000], Loss: 13.5753\n",
+ "Epoch [100/10000], Loss: 15.5835\n",
+ "Epoch [110/10000], Loss: 14.3249\n",
+ "Epoch [120/10000], Loss: 11.1069\n",
+ "Epoch [130/10000], Loss: 11.6783\n",
+ "Epoch [140/10000], Loss: 10.2477\n",
+ "Epoch [150/10000], Loss: 10.8494\n",
+ "Epoch [160/10000], Loss: 8.3007\n",
+ "Epoch [170/10000], Loss: 6.8133\n",
+ "Epoch [180/10000], Loss: 5.5992\n",
+ "Epoch [190/10000], Loss: 6.9212\n",
+ "Epoch [200/10000], Loss: 5.9311\n",
+ "Epoch [210/10000], Loss: 6.0747\n",
+ "Epoch [220/10000], Loss: 4.9251\n",
+ "Epoch [230/10000], Loss: 3.9548\n",
+ "Epoch [240/10000], Loss: 5.9888\n",
+ "Epoch [250/10000], Loss: 4.9153\n",
+ "Epoch [260/10000], Loss: 6.6282\n",
+ "Epoch [270/10000], Loss: 4.7945\n",
+ "Epoch [280/10000], Loss: 6.8866\n",
+ "Epoch [290/10000], Loss: 5.7963\n",
+ "Epoch [300/10000], Loss: 4.1406\n",
+ "Epoch [310/10000], Loss: 5.8112\n",
+ "Epoch [320/10000], Loss: 6.3739\n",
+ "Epoch [330/10000], Loss: 4.7297\n",
+ "Epoch [340/10000], Loss: 3.6125\n",
+ "Epoch [350/10000], Loss: 4.7553\n",
+ "Epoch [360/10000], Loss: 5.1536\n",
+ "Epoch [370/10000], Loss: 3.3294\n",
+ "Epoch [380/10000], Loss: 4.8955\n",
+ "Epoch [390/10000], Loss: 5.0702\n",
+ "Epoch [400/10000], Loss: 5.5217\n",
+ "Epoch [410/10000], Loss: 4.0543\n",
+ "Epoch [420/10000], Loss: 3.8583\n",
+ "Epoch [430/10000], Loss: 4.4484\n",
+ "Epoch [440/10000], Loss: 6.0914\n",
+ "Epoch [450/10000], Loss: 5.3544\n",
+ "Epoch [460/10000], Loss: 3.1850\n",
+ "Epoch [470/10000], Loss: 4.5308\n",
+ "Epoch [480/10000], Loss: 3.6213\n",
+ "Epoch [490/10000], Loss: 3.3625\n",
+ "Epoch [500/10000], Loss: 3.4060\n",
+ "Epoch [510/10000], Loss: 3.2437\n",
+ "Epoch [520/10000], Loss: 3.9425\n",
+ "Epoch [530/10000], Loss: 3.4496\n",
+ "Epoch [540/10000], Loss: 2.8899\n",
+ "Epoch [550/10000], Loss: 2.5607\n",
+ "Epoch [560/10000], Loss: 3.9549\n",
+ "Epoch [570/10000], Loss: 4.4588\n",
+ "Epoch [580/10000], Loss: 4.3738\n",
+ "Epoch [590/10000], Loss: 3.3019\n",
+ "Epoch [600/10000], Loss: 2.4798\n",
+ "Epoch [610/10000], Loss: 6.1956\n",
+ "Epoch [620/10000], Loss: 4.3365\n",
+ "Epoch [630/10000], Loss: 4.9766\n",
+ "Epoch [640/10000], Loss: 4.2719\n",
+ "Epoch [650/10000], Loss: 5.0380\n",
+ "Epoch [660/10000], Loss: 5.7970\n",
+ "Epoch [670/10000], Loss: 5.8626\n",
+ "Epoch [680/10000], Loss: 4.7593\n",
+ "Epoch [690/10000], Loss: 5.7902\n",
+ "Epoch [700/10000], Loss: 5.9829\n",
+ "Epoch [710/10000], Loss: 6.7365\n",
+ "Epoch [720/10000], Loss: 5.3005\n",
+ "Epoch [730/10000], Loss: 5.8437\n",
+ "Epoch [740/10000], Loss: 3.8711\n",
+ "Epoch [750/10000], Loss: 5.0535\n",
+ "Epoch [760/10000], Loss: 3.7943\n",
+ "Epoch [770/10000], Loss: 4.5757\n",
+ "Epoch [780/10000], Loss: 4.7215\n",
+ "Epoch [790/10000], Loss: 4.4653\n",
+ "Epoch [800/10000], Loss: 5.5769\n",
+ "Epoch [810/10000], Loss: 5.0396\n",
+ "Epoch [820/10000], Loss: 5.2878\n",
+ "Epoch [830/10000], Loss: 5.4599\n",
+ "Epoch [840/10000], Loss: 4.5172\n",
+ "Epoch [850/10000], Loss: 4.8925\n",
+ "Epoch [860/10000], Loss: 5.1588\n",
+ "Epoch [870/10000], Loss: 5.8972\n",
+ "Epoch [880/10000], Loss: 4.9056\n",
+ "Epoch [890/10000], Loss: 4.8735\n",
+ "Epoch [900/10000], Loss: 5.2677\n",
+ "Epoch [910/10000], Loss: 4.3955\n",
+ "Epoch [920/10000], Loss: 5.1297\n",
+ "Epoch [930/10000], Loss: 4.2394\n",
+ "Epoch [940/10000], Loss: 6.8890\n",
+ "Epoch [950/10000], Loss: 5.1845\n",
+ "Epoch [960/10000], Loss: 4.4620\n",
+ "Epoch [970/10000], Loss: 5.1748\n",
+ "Epoch [980/10000], Loss: 4.9878\n",
+ "Epoch [990/10000], Loss: 3.9090\n",
+ "Epoch [1000/10000], Loss: 3.6580\n",
+ "Epoch [1010/10000], Loss: 3.5537\n",
+ "Epoch [1020/10000], Loss: 4.5068\n",
+ "Epoch [1030/10000], Loss: 3.2602\n",
+ "Epoch [1040/10000], Loss: 3.7987\n",
+ "Epoch [1050/10000], Loss: 3.3821\n",
+ "Epoch [1060/10000], Loss: 3.9663\n",
+ "Epoch [1070/10000], Loss: 3.8562\n",
+ "Epoch [1080/10000], Loss: 3.7811\n",
+ "Epoch [1090/10000], Loss: 4.0012\n",
+ "Epoch [1100/10000], Loss: 4.2564\n",
+ "Epoch [1110/10000], Loss: 3.2248\n",
+ "Epoch [1120/10000], Loss: 3.7483\n",
+ "Epoch [1130/10000], Loss: 3.0315\n",
+ "Epoch [1140/10000], Loss: 3.3677\n",
+ "Epoch [1150/10000], Loss: 3.5367\n",
+ "Epoch [1160/10000], Loss: 2.7389\n",
+ "Epoch [1170/10000], Loss: 3.0337\n",
+ "Epoch [1180/10000], Loss: 2.2630\n",
+ "Epoch [1190/10000], Loss: 1.8084\n",
+ "Epoch [1200/10000], Loss: 3.3239\n",
+ "Epoch [1210/10000], Loss: 4.1555\n",
+ "Epoch [1220/10000], Loss: 2.8362\n",
+ "Epoch [1230/10000], Loss: 3.1269\n",
+ "Epoch [1240/10000], Loss: 2.2517\n",
+ "Epoch [1250/10000], Loss: 2.9400\n",
+ "Epoch [1260/10000], Loss: 2.5436\n",
+ "Epoch [1270/10000], Loss: 2.6504\n",
+ "Epoch [1280/10000], Loss: 2.9571\n",
+ "Epoch [1290/10000], Loss: 2.4060\n",
+ "Epoch [1300/10000], Loss: 2.6461\n",
+ "Epoch [1310/10000], Loss: 2.4692\n",
+ "Epoch [1320/10000], Loss: 2.0638\n",
+ "Epoch [1330/10000], Loss: 3.0852\n",
+ "Epoch [1340/10000], Loss: 2.3448\n",
+ "Epoch [1350/10000], Loss: 2.6796\n",
+ "Epoch [1360/10000], Loss: 2.0310\n",
+ "Epoch [1370/10000], Loss: 1.8680\n",
+ "Epoch [1380/10000], Loss: 2.1846\n",
+ "Epoch [1390/10000], Loss: 2.3017\n",
+ "Epoch [1400/10000], Loss: 1.6519\n",
+ "Epoch [1410/10000], Loss: 1.6228\n",
+ "Epoch [1420/10000], Loss: 1.4328\n",
+ "Epoch [1430/10000], Loss: 1.5642\n",
+ "Epoch [1440/10000], Loss: 1.8962\n",
+ "Epoch [1450/10000], Loss: 1.4433\n",
+ "Epoch [1460/10000], Loss: 2.1973\n",
+ "Epoch [1470/10000], Loss: 1.8118\n",
+ "Epoch [1480/10000], Loss: 1.7155\n",
+ "Epoch [1490/10000], Loss: 2.7671\n",
+ "Epoch [1500/10000], Loss: 2.4518\n",
+ "Epoch [1510/10000], Loss: 1.0202\n",
+ "Epoch [1520/10000], Loss: 2.2548\n",
+ "Epoch [1530/10000], Loss: 1.4305\n",
+ "Epoch [1540/10000], Loss: 2.1796\n",
+ "Epoch [1550/10000], Loss: 1.7766\n",
+ "Epoch [1560/10000], Loss: 2.0751\n",
+ "Epoch [1570/10000], Loss: 1.6544\n",
+ "Epoch [1580/10000], Loss: 2.8767\n",
+ "Epoch [1590/10000], Loss: 2.2069\n",
+ "Epoch [1600/10000], Loss: 1.5974\n",
+ "Epoch [1610/10000], Loss: 2.0101\n",
+ "Epoch [1620/10000], Loss: 1.9445\n",
+ "Epoch [1630/10000], Loss: 2.3080\n",
+ "Epoch [1640/10000], Loss: 1.2395\n",
+ "Epoch [1650/10000], Loss: 1.2486\n",
+ "Epoch [1660/10000], Loss: 1.3250\n",
+ "Epoch [1670/10000], Loss: 1.1839\n",
+ "Epoch [1680/10000], Loss: 2.0569\n",
+ "Epoch [1690/10000], Loss: 2.3591\n",
+ "Epoch [1700/10000], Loss: 1.4618\n",
+ "Epoch [1710/10000], Loss: 1.3663\n",
+ "Epoch [1720/10000], Loss: 0.7300\n",
+ "Epoch [1730/10000], Loss: 2.6778\n",
+ "Epoch [1740/10000], Loss: 2.2931\n",
+ "Epoch [1750/10000], Loss: 1.4289\n",
+ "Epoch [1760/10000], Loss: 1.1453\n",
+ "Epoch [1770/10000], Loss: 1.0400\n",
+ "Epoch [1780/10000], Loss: 1.3603\n",
+ "Epoch [1790/10000], Loss: 1.5058\n",
+ "Epoch [1800/10000], Loss: 1.2890\n",
+ "Epoch [1810/10000], Loss: 1.1259\n",
+ "Epoch [1820/10000], Loss: 1.3006\n",
+ "Epoch [1830/10000], Loss: 1.4118\n",
+ "Epoch [1840/10000], Loss: 1.6406\n",
+ "Epoch [1850/10000], Loss: 1.3513\n",
+ "Epoch [1860/10000], Loss: 1.2380\n",
+ "Epoch [1870/10000], Loss: 1.4618\n",
+ "Epoch [1880/10000], Loss: 2.8634\n",
+ "Epoch [1890/10000], Loss: 2.4145\n",
+ "Epoch [1900/10000], Loss: 2.1412\n",
+ "Epoch [1910/10000], Loss: 2.7974\n",
+ "Epoch [1920/10000], Loss: 2.3607\n",
+ "Epoch [1930/10000], Loss: 2.1780\n",
+ "Epoch [1940/10000], Loss: 2.1544\n",
+ "Epoch [1950/10000], Loss: 1.1798\n",
+ "Epoch [1960/10000], Loss: 2.0259\n",
+ "Epoch [1970/10000], Loss: 8.5005\n",
+ "Epoch [1980/10000], Loss: 7.2836\n",
+ "Epoch [1990/10000], Loss: 5.1658\n",
+ "Epoch [2000/10000], Loss: 4.4845\n",
+ "Epoch [2010/10000], Loss: 3.2873\n",
+ "Epoch [2020/10000], Loss: 3.9213\n",
+ "Epoch [2030/10000], Loss: 3.4896\n",
+ "Epoch [2040/10000], Loss: 4.8792\n",
+ "Epoch [2050/10000], Loss: 3.8883\n",
+ "Epoch [2060/10000], Loss: 4.8546\n",
+ "Epoch [2070/10000], Loss: 3.5432\n",
+ "Epoch [2080/10000], Loss: 4.3267\n",
+ "Epoch [2090/10000], Loss: 3.4671\n",
+ "Epoch [2100/10000], Loss: 5.4011\n",
+ "Epoch [2110/10000], Loss: 5.8443\n",
+ "Epoch [2120/10000], Loss: 2.8416\n",
+ "Epoch [2130/10000], Loss: 5.1449\n",
+ "Epoch [2140/10000], Loss: 5.6858\n",
+ "Epoch [2150/10000], Loss: 5.1926\n",
+ "Epoch [2160/10000], Loss: 4.5664\n",
+ "Epoch [2170/10000], Loss: 4.3358\n",
+ "Epoch [2180/10000], Loss: 4.5456\n",
+ "Epoch [2190/10000], Loss: 3.5273\n",
+ "Epoch [2200/10000], Loss: 6.8660\n",
+ "Epoch [2210/10000], Loss: 3.4792\n",
+ "Epoch [2220/10000], Loss: 3.4052\n",
+ "Epoch [2230/10000], Loss: 2.8651\n",
+ "Epoch [2240/10000], Loss: 2.1104\n",
+ "Epoch [2250/10000], Loss: 2.1549\n",
+ "Epoch [2260/10000], Loss: 1.8513\n",
+ "Epoch [2270/10000], Loss: 2.8559\n",
+ "Epoch [2280/10000], Loss: 1.3817\n",
+ "Epoch [2290/10000], Loss: 1.4976\n",
+ "Epoch [2300/10000], Loss: 1.7325\n",
+ "Epoch [2310/10000], Loss: 1.5967\n",
+ "Epoch [2320/10000], Loss: 0.8749\n",
+ "Epoch [2330/10000], Loss: 1.5636\n",
+ "Epoch [2340/10000], Loss: 1.5302\n",
+ "Epoch [2350/10000], Loss: 0.7900\n",
+ "Epoch [2360/10000], Loss: 1.0777\n",
+ "Epoch [2370/10000], Loss: 0.6089\n",
+ "Epoch [2380/10000], Loss: 1.2180\n",
+ "Epoch [2390/10000], Loss: 1.3731\n",
+ "Epoch [2400/10000], Loss: 1.1782\n",
+ "Epoch [2410/10000], Loss: 0.9826\n",
+ "Epoch [2420/10000], Loss: 1.8233\n",
+ "Epoch [2430/10000], Loss: 0.8246\n",
+ "Epoch [2440/10000], Loss: 0.7204\n",
+ "Epoch [2450/10000], Loss: 1.2327\n",
+ "Epoch [2460/10000], Loss: 1.2843\n",
+ "Epoch [2470/10000], Loss: 1.1326\n",
+ "Epoch [2480/10000], Loss: 1.1369\n",
+ "Epoch [2490/10000], Loss: 1.0106\n",
+ "Epoch [2500/10000], Loss: 1.4563\n",
+ "Epoch [2510/10000], Loss: 1.0163\n",
+ "Epoch [2520/10000], Loss: 0.3823\n",
+ "Epoch [2530/10000], Loss: 0.8872\n",
+ "Epoch [2540/10000], Loss: 0.4118\n",
+ "Epoch [2550/10000], Loss: 0.4925\n",
+ "Epoch [2560/10000], Loss: 0.1141\n",
+ "Epoch [2570/10000], Loss: 0.5785\n",
+ "Epoch [2580/10000], Loss: 0.5831\n",
+ "Epoch [2590/10000], Loss: 0.0845\n",
+ "Epoch [2600/10000], Loss: 0.5621\n",
+ "Epoch [2610/10000], Loss: 1.0745\n",
+ "Epoch [2620/10000], Loss: 0.2378\n",
+ "Epoch [2630/10000], Loss: 0.6215\n",
+ "Epoch [2640/10000], Loss: 0.7897\n",
+ "Epoch [2650/10000], Loss: 0.9359\n",
+ "Epoch [2660/10000], Loss: 0.5567\n",
+ "Epoch [2670/10000], Loss: 4.0690\n",
+ "Epoch [2680/10000], Loss: 3.3254\n",
+ "Epoch [2690/10000], Loss: 3.3888\n",
+ "Epoch [2700/10000], Loss: 3.7329\n",
+ "Epoch [2710/10000], Loss: 2.9879\n",
+ "Epoch [2720/10000], Loss: 2.6283\n",
+ "Epoch [2730/10000], Loss: 2.4366\n",
+ "Epoch [2740/10000], Loss: 2.8078\n",
+ "Epoch [2750/10000], Loss: 1.2757\n",
+ "Epoch [2760/10000], Loss: 1.0685\n",
+ "Epoch [2770/10000], Loss: 2.2750\n",
+ "Epoch [2780/10000], Loss: 1.9740\n",
+ "Epoch [2790/10000], Loss: 1.4824\n",
+ "Epoch [2800/10000], Loss: 1.1974\n",
+ "Epoch [2810/10000], Loss: 1.6074\n",
+ "Epoch [2820/10000], Loss: 1.2541\n",
+ "Epoch [2830/10000], Loss: 1.7665\n",
+ "Epoch [2840/10000], Loss: 1.9748\n",
+ "Epoch [2850/10000], Loss: 1.9842\n",
+ "Epoch [2860/10000], Loss: 2.5544\n",
+ "Epoch [2870/10000], Loss: 1.6564\n",
+ "Epoch [2880/10000], Loss: 1.0362\n",
+ "Epoch [2890/10000], Loss: 1.3166\n",
+ "Epoch [2900/10000], Loss: 2.4819\n",
+ "Epoch [2910/10000], Loss: 1.1353\n",
+ "Epoch [2920/10000], Loss: 1.6106\n",
+ "Epoch [2930/10000], Loss: 2.1840\n",
+ "Epoch [2940/10000], Loss: 1.4362\n",
+ "Epoch [2950/10000], Loss: 0.9568\n",
+ "Epoch [2960/10000], Loss: 1.8224\n",
+ "Epoch [2970/10000], Loss: 1.2919\n",
+ "Epoch [2980/10000], Loss: 1.1351\n",
+ "Epoch [2990/10000], Loss: 0.7588\n",
+ "Epoch [3000/10000], Loss: 1.2207\n",
+ "Epoch [3010/10000], Loss: 1.3446\n",
+ "Epoch [3020/10000], Loss: 1.1581\n",
+ "Epoch [3030/10000], Loss: 1.0448\n",
+ "Epoch [3040/10000], Loss: 2.0898\n",
+ "Epoch [3050/10000], Loss: 1.1978\n",
+ "Epoch [3060/10000], Loss: 1.2886\n",
+ "Epoch [3070/10000], Loss: 1.0066\n",
+ "Epoch [3080/10000], Loss: 1.5037\n",
+ "Epoch [3090/10000], Loss: 0.6185\n",
+ "Epoch [3100/10000], Loss: 0.9835\n",
+ "Epoch [3110/10000], Loss: 0.6671\n",
+ "Epoch [3120/10000], Loss: 0.6967\n",
+ "Epoch [3130/10000], Loss: 1.0013\n",
+ "Epoch [3140/10000], Loss: 1.4123\n",
+ "Epoch [3150/10000], Loss: 1.8096\n",
+ "Epoch [3160/10000], Loss: 1.1270\n",
+ "Epoch [3170/10000], Loss: 1.9781\n",
+ "Epoch [3180/10000], Loss: 1.0191\n",
+ "Epoch [3190/10000], Loss: 2.4403\n",
+ "Epoch [3200/10000], Loss: 0.8882\n",
+ "Epoch [3210/10000], Loss: 1.0005\n",
+ "Epoch [3220/10000], Loss: 0.5002\n",
+ "Epoch [3230/10000], Loss: 1.2351\n",
+ "Epoch [3240/10000], Loss: 0.8264\n",
+ "Epoch [3250/10000], Loss: 0.7881\n",
+ "Epoch [3260/10000], Loss: 1.4120\n",
+ "Epoch [3270/10000], Loss: 0.3342\n",
+ "Epoch [3280/10000], Loss: 1.3266\n",
+ "Epoch [3290/10000], Loss: 1.7115\n",
+ "Epoch [3300/10000], Loss: 1.0647\n",
+ "Epoch [3310/10000], Loss: 0.4623\n",
+ "Epoch [3320/10000], Loss: 1.2075\n",
+ "Epoch [3330/10000], Loss: 0.4555\n",
+ "Epoch [3340/10000], Loss: 0.8706\n",
+ "Epoch [3350/10000], Loss: 0.9383\n",
+ "Epoch [3360/10000], Loss: 0.7436\n",
+ "Epoch [3370/10000], Loss: 0.8466\n",
+ "Epoch [3380/10000], Loss: 0.8379\n",
+ "Epoch [3390/10000], Loss: 0.6832\n",
+ "Epoch [3400/10000], Loss: 1.0414\n",
+ "Epoch [3410/10000], Loss: 0.5256\n",
+ "Epoch [3420/10000], Loss: 1.2059\n",
+ "Epoch [3430/10000], Loss: 0.7949\n",
+ "Epoch [3440/10000], Loss: 0.5962\n",
+ "Epoch [3450/10000], Loss: 0.8650\n",
+ "Epoch [3460/10000], Loss: 0.8154\n",
+ "Epoch [3470/10000], Loss: 0.4271\n",
+ "Epoch [3480/10000], Loss: 0.5725\n",
+ "Epoch [3490/10000], Loss: 1.0792\n",
+ "Epoch [3500/10000], Loss: 0.5633\n",
+ "Epoch [3510/10000], Loss: 0.2986\n",
+ "Epoch [3520/10000], Loss: 0.3941\n",
+ "Epoch [3530/10000], Loss: 1.0033\n",
+ "Epoch [3540/10000], Loss: 0.1960\n",
+ "Epoch [3550/10000], Loss: 0.9863\n",
+ "Epoch [3560/10000], Loss: 0.4395\n",
+ "Epoch [3570/10000], Loss: 0.9612\n",
+ "Epoch [3580/10000], Loss: 2.4734\n",
+ "Epoch [3590/10000], Loss: 5.5539\n",
+ "Epoch [3600/10000], Loss: 3.7807\n",
+ "Epoch [3610/10000], Loss: 4.0435\n",
+ "Epoch [3620/10000], Loss: 4.1143\n",
+ "Epoch [3630/10000], Loss: 3.3714\n",
+ "Epoch [3640/10000], Loss: 3.3396\n",
+ "Epoch [3650/10000], Loss: 4.2713\n",
+ "Epoch [3660/10000], Loss: 2.2012\n",
+ "Epoch [3670/10000], Loss: 1.7168\n",
+ "Epoch [3680/10000], Loss: 2.2133\n",
+ "Epoch [3690/10000], Loss: 2.7070\n",
+ "Epoch [3700/10000], Loss: 3.3160\n",
+ "Epoch [3710/10000], Loss: 3.6073\n",
+ "Epoch [3720/10000], Loss: 2.2879\n",
+ "Epoch [3730/10000], Loss: 2.8893\n",
+ "Epoch [3740/10000], Loss: 2.7971\n",
+ "Epoch [3750/10000], Loss: 1.7426\n",
+ "Epoch [3760/10000], Loss: 2.7662\n",
+ "Epoch [3770/10000], Loss: 2.1203\n",
+ "Epoch [3780/10000], Loss: 3.8798\n",
+ "Epoch [3790/10000], Loss: 2.6662\n",
+ "Epoch [3800/10000], Loss: 1.8491\n",
+ "Epoch [3810/10000], Loss: 1.5527\n",
+ "Epoch [3820/10000], Loss: 4.1708\n",
+ "Epoch [3830/10000], Loss: 1.6162\n",
+ "Epoch [3840/10000], Loss: 2.6064\n",
+ "Epoch [3850/10000], Loss: 1.9516\n",
+ "Epoch [3860/10000], Loss: 2.1771\n",
+ "Epoch [3870/10000], Loss: 2.3933\n",
+ "Epoch [3880/10000], Loss: 1.8314\n",
+ "Epoch [3890/10000], Loss: 3.2097\n",
+ "Epoch [3900/10000], Loss: 1.8215\n",
+ "Epoch [3910/10000], Loss: 2.1805\n",
+ "Epoch [3920/10000], Loss: 1.6260\n",
+ "Epoch [3930/10000], Loss: 1.5388\n",
+ "Epoch [3940/10000], Loss: 1.4422\n",
+ "Epoch [3950/10000], Loss: 1.2274\n",
+ "Epoch [3960/10000], Loss: 2.1992\n",
+ "Epoch [3970/10000], Loss: 1.3654\n",
+ "Epoch [3980/10000], Loss: 1.5024\n",
+ "Epoch [3990/10000], Loss: 1.9630\n",
+ "Epoch [4000/10000], Loss: 1.7742\n",
+ "Epoch [4010/10000], Loss: 1.9968\n",
+ "Epoch [4020/10000], Loss: 2.2213\n",
+ "Epoch [4030/10000], Loss: 2.1302\n",
+ "Epoch [4040/10000], Loss: 2.3094\n",
+ "Epoch [4050/10000], Loss: 2.3253\n",
+ "Epoch [4060/10000], Loss: 2.2924\n",
+ "Epoch [4070/10000], Loss: 0.9917\n",
+ "Epoch [4080/10000], Loss: 1.8697\n",
+ "Epoch [4090/10000], Loss: 1.4888\n",
+ "Epoch [4100/10000], Loss: 0.6206\n",
+ "Epoch [4110/10000], Loss: 0.9877\n",
+ "Epoch [4120/10000], Loss: 1.2839\n",
+ "Epoch [4130/10000], Loss: 0.4944\n",
+ "Epoch [4140/10000], Loss: 0.6533\n",
+ "Epoch [4150/10000], Loss: 0.4354\n",
+ "Epoch [4160/10000], Loss: 0.4216\n"
]
}
],
diff --git a/loss b/loss
index 1e0cc2f..08f0413 100644
--- a/loss
+++ b/loss
@@ -1,800 +1,7588 @@
-87.5
-70.5
-59.75
-54.0
-50.75
-48.5
-46.75
-45.75
-44.75
-44.5
-44.25
-43.75
-43.75
-43.5
-43.25
-43.25
-43.0
-42.5
-42.75
-42.5
-42.0
-42.25
-42.0
-41.75
-42.25
-41.75
-41.5
-41.75
-41.25
-41.25
-41.5
-41.25
-41.0
-40.75
-40.75
-40.75
-40.5
-40.5
-40.25
-40.0
-39.75
-39.5
-38.75
-38.5
-37.5
-36.5
-35.5
-34.5
-33.5
-33.0
-32.75
-33.0
-33.0
-32.75
-32.5
-32.25
-31.875
-31.625
-31.5
-31.25
-31.125
-30.875
-30.875
-30.75
-30.625
-30.5
-30.375
-30.375
-30.125
-30.0
-29.875
-29.875
-29.75
-29.5
-29.5
-29.5
-29.375
-29.375
-29.25
-29.125
-29.125
-29.125
-29.0
-28.875
-29.0
-28.875
-28.875
-28.75
-28.75
-28.875
-28.75
-28.75
-28.75
-28.75
-28.75
-28.75
-28.75
-28.75
-28.75
-28.875
-28.75
-28.75
-28.75
-28.625
-28.625
-28.75
-28.75
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.5
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.75
-28.625
-28.625
-28.625
-28.625
-28.625
-28.5
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.625
-28.5
-28.625
-28.625
-28.5
-28.375
-28.375
-28.625
-28.5
-28.5
-28.5
-28.375
-28.5
-28.5
-28.5
-28.5
-28.5
-28.375
-28.375
-28.25
-28.125
-27.875
-27.875
-27.75
-27.875
-27.625
-27.25
-27.5
-26.875
-26.75
-26.5
-26.5
-25.75
-25.375
-24.75
-23.75
-23.375
-22.875
-22.375
-22.25
-21.625
-21.375
-20.875
-20.625
-20.375
-20.0
-19.625
-19.5
-19.0
-18.75
-18.5
-18.125
-17.625
-17.375
-17.0
-16.75
-16.5
-16.375
-15.9375
-15.75
-15.25
-15.0
-14.8125
-14.625
-14.375
-14.125
-13.8125
-13.6875
-13.5625
-13.0625
-13.0
-12.875
-12.5
-12.4375
-12.125
-11.9375
-11.75
-11.625
-11.375
-11.0625
-11.0
-10.8125
-10.75
-10.5625
-10.375
-10.3125
-10.125
-10.0625
-10.0
-9.8125
-9.75
-9.4375
-9.5
-9.4375
-9.3125
-9.1875
-9.0625
-9.0625
-9.0
-8.875
-8.75
-8.8125
-8.625
-8.5625
-8.6875
-8.625
-8.5
-8.4375
-8.125
-8.25
-8.25
-8.0625
-8.0625
-8.0
-7.875
-7.78125
-7.78125
-7.84375
-7.84375
-8.3125
-10.625
-13.3125
-7.90625
-10.625
-7.875
-9.625
-8.8125
-8.4375
-8.9375
-8.1875
-8.125
-8.5
-8.1875
-7.71875
-8.0625
-8.125
-7.65625
-7.625
-7.8125
-7.6875
-7.46875
-7.59375
-7.5625
-7.4375
-7.40625
-7.46875
-7.4375
-7.375
-7.21875
-7.21875
-7.25
-7.1875
-7.0625
-7.15625
-7.1875
-7.0625
-7.0625
-7.0625
-7.09375
-6.9375
-6.875
-6.9375
-6.9375
-6.84375
-6.875
-6.84375
-6.9375
-6.75
-6.78125
-6.84375
-6.6875
-6.65625
-6.625
-6.65625
-6.65625
-6.53125
-6.59375
-6.625
-6.5
-6.53125
-6.5625
-6.5625
-6.53125
-6.4375
-6.5
-6.5
-6.5
-6.4375
-6.46875
-6.46875
-6.46875
-6.375
-6.3125
-6.40625
-6.34375
-6.3125
-6.25
-6.28125
-6.1875
-6.3125
-6.28125
-6.21875
-6.21875
-6.25
-6.15625
-6.09375
-6.1875
-6.28125
-6.09375
-6.21875
-6.09375
-6.09375
-6.09375
-6.15625
-6.0625
-6.03125
-6.0
-6.0
-6.03125
-5.96875
-6.0
-6.03125
-5.96875
-5.875
-5.9375
-5.8125
-5.90625
-5.90625
-5.8125
-5.84375
-5.84375
-5.78125
-5.90625
-5.75
-5.78125
-5.78125
-5.78125
-5.71875
-5.78125
-5.65625
-5.65625
-5.78125
-5.78125
-5.71875
-5.65625
-5.5625
-5.5625
-5.65625
-5.6875
-5.59375
-5.59375
-5.53125
-5.59375
-5.5
-5.59375
-5.53125
-5.5
-5.4375
-5.53125
-5.5625
-5.5
-5.375
-5.46875
-5.53125
-5.40625
-5.4375
-5.53125
-5.40625
-5.40625
-5.3125
-5.375
-5.3125
-5.375
-5.3125
-5.25
-5.3125
-5.1875
-5.21875
-5.28125
-5.09375
-5.21875
-5.125
-5.125
-5.15625
-5.03125
-5.0625
-5.03125
-5.03125
-5.03125
-4.96875
-5.03125
-4.96875
-5.0
-4.9375
-4.875
-4.90625
-4.96875
-4.9375
-4.84375
-4.90625
-4.90625
-4.875
-4.78125
-4.84375
-4.96875
-6.15625
-14.0
-11.1875
-7.625
-12.3125
-7.0625
-8.75
-8.875
-6.65625
-7.5
-8.4375
-7.15625
-6.375
-7.0625
-7.15625
-6.5625
-6.4375
-6.5625
-6.40625
-6.1875
-5.96875
-6.15625
-6.03125
-5.8125
-5.59375
-5.65625
-5.6875
-5.65625
-5.46875
-5.46875
-5.4375
-5.28125
-5.28125
-5.3125
-5.3125
-5.21875
-5.125
-5.125
-5.125
-5.03125
-4.96875
-4.9375
-4.96875
-4.96875
-4.96875
-4.90625
-4.8125
-4.875
-4.8125
-4.8125
-4.75
-4.6875
-4.75
-4.71875
-4.65625
-4.65625
-4.71875
-4.6875
-4.6875
-4.65625
-4.6875
-4.6875
-4.625
-4.65625
-4.65625
-4.59375
-4.65625
-4.53125
-4.53125
-4.5625
-4.5625
-4.59375
-4.53125
-4.5625
-4.5
-4.59375
-4.5625
-4.5625
-4.5625
-4.46875
-4.46875
-4.4375
-4.5
-4.53125
-4.46875
-4.5
-4.375
-4.46875
-4.5
-4.5
-4.46875
-4.375
-4.34375
-4.46875
-4.4375
-4.25
-4.375
-4.40625
-4.34375
-4.40625
-4.375
-4.3125
-4.375
-4.34375
-4.34375
-4.3125
-4.34375
-4.40625
-4.3125
-4.34375
-4.28125
-4.25
-4.3125
-4.3125
-4.28125
-4.1875
-4.25
-4.28125
-4.34375
-4.3125
-4.15625
-4.21875
-4.3125
-4.25
-4.25
-4.1875
-4.28125
-4.15625
-4.21875
-4.25
-4.25
-4.25
-4.15625
-4.1875
-4.15625
-4.25
-4.21875
-4.125
-4.25
-4.1875
-4.25
-4.125
-4.21875
-4.125
-4.125
-4.15625
-4.0625
-4.1875
-4.21875
-4.09375
-4.15625
-4.125
-4.15625
-4.09375
-4.125
-4.03125
-4.09375
-4.125
-4.0625
-4.0
-4.125
-4.0625
-4.0
-4.15625
-4.0625
-4.0625
-4.03125
-4.15625
-4.0625
-4.0625
-3.96875
-4.03125
-4.0
-3.953125
-4.03125
-4.03125
-4.03125
-3.984375
-3.984375
-3.984375
-3.9375
-3.953125
-4.0
-3.9375
-3.96875
-3.9375
-3.984375
-3.96875
-4.03125
-3.96875
-4.0
-3.921875
-4.03125
-3.90625
-3.921875
-3.875
-3.953125
-3.96875
-3.890625
-3.828125
-3.953125
-3.953125
-3.984375
-3.921875
-3.9375
-3.921875
-3.921875
-3.984375
-3.875
-3.84375
-3.828125
-3.890625
-3.875
-3.859375
-3.859375
-3.859375
-3.796875
-3.828125
-3.84375
-3.796875
-3.8125
-3.828125
-3.8125
-3.859375
-3.75
-3.78125
-3.71875
-3.765625
-3.796875
-3.734375
-3.75
-3.765625
-3.765625
-3.71875
-3.75
-3.734375
-3.734375
-3.765625
-3.6875
-3.703125
-3.75
-3.6875
-3.734375
-3.703125
-3.734375
-3.703125
-3.703125
-3.75
-3.71875
-3.640625
-3.6875
-3.6875
-3.703125
-3.671875
-3.609375
-3.734375
-3.671875
-3.65625
-3.65625
-3.59375
-3.6875
-3.6875
-3.546875
-3.625
-3.609375
-3.59375
-3.640625
-3.546875
-3.625
-3.546875
-3.609375
-3.578125
-3.46875
-3.578125
-3.5625
-3.609375
-3.625
-3.546875
-3.5
-3.546875
-3.515625
-3.515625
-3.53125
-3.546875
-3.609375
-3.515625
-3.546875
-3.53125
-3.484375
-3.53125
-3.5
-3.46875
-3.453125
-3.59375
-3.515625
-3.5625
-3.453125
-3.453125
-3.515625
-3.453125
-3.453125
-3.4375
-3.359375
-3.46875
-3.390625
-3.46875
-3.390625
-3.421875
-3.4375
-3.421875
-3.421875
-3.4375
-3.359375
-3.359375
-3.375
-3.375
-3.359375
-3.390625
-3.375
-3.359375
-3.34375
-3.328125
-3.34375
-3.375
-3.34375
-3.3125
-3.296875
-3.375
-3.265625
-3.25
-3.265625
-3.34375
-3.3125
-3.34375
-3.3125
-3.25
-3.28125
-3.328125
-3.296875
-3.203125
-3.296875
-3.375
-3.46875
-4.5625
-10.6875
-11.5625
-12.5
-7.3125
-9.125
-9.4375
-5.96875
-6.625
-7.5
-6.625
+89.41087341308594 1.2944836616516113 4.140069961547852 9.478320121765137 226.46401977539062
+79.52562713623047 0.1247667446732521 1.6021499633789062 5.1414031982421875 203.83050537109375
+72.3871078491211 0.1213218942284584 0.4415026009082794 2.6775705814361572 186.66175842285156
+67.79345703125 0.628143310546875 0.056703727692365646 1.3954437971115112 174.3647918701172
+64.224609375 1.3127467632293701 0.02839338220655918 0.6893171668052673 165.00244140625
+61.02769088745117 2.11295223236084 0.21188797056674957 0.2705949544906616 157.12066650390625
+58.626319885253906 2.9901225566864014 0.545389711856842 0.06147243082523346 150.21148681640625
+56.646629333496094 3.900123119354248 0.973063051700592 0.001802295446395874 144.27072143554688
+54.369544982910156 4.794400691986084 1.4417238235473633 0.04604586958885193 139.1782684326172
+53.34144592285156 5.655455112457275 1.929046392440796 0.15766218304634094 134.8350830078125
+51.55065155029297 6.483419895172119 2.4200546741485596 0.3153894245624542 131.01443481445312
+50.88105010986328 7.276472091674805 2.9086740016937256 0.5034157037734985 127.60140991210938
+50.05271530151367 8.017772674560547 3.382845163345337 0.7096508741378784 124.59398651123047
+49.23320388793945 8.700709342956543 3.830545425415039 0.9213890433311462 121.9515380859375
+49.066341400146484 9.323756217956543 4.249921798706055 1.1340655088424683 119.6395492553711
+47.769195556640625 9.910093307495117 4.649497032165527 1.3441548347473145 117.56342315673828
+47.809688568115234 10.48090648651123 5.041550636291504 1.5587375164031982 115.62959289550781
+47.411128997802734 11.044978141784668 5.4304118156433105 1.7781922817230225 113.80445098876953
+46.585548400878906 11.592522621154785 5.816767692565918 2.0002901554107666 112.0915756225586
+46.65435791015625 12.113420486450195 6.18118953704834 2.216142177581787 110.53479766845703
+46.219085693359375 12.59675121307373 6.521310329437256 2.4176652431488037 109.13500213623047
+45.719078063964844 13.030357360839844 6.832784652709961 2.6087472438812256 107.89032745361328
+45.9516716003418 13.42202091217041 7.115777492523193 2.7829692363739014 106.79149627685547
+45.318580627441406 13.776152610778809 7.371881484985352 2.9433064460754395 105.81124877929688
+45.23729705810547 14.094908714294434 7.605937957763672 3.0911128520965576 104.93331146240234
+44.75959014892578 14.384644508361816 7.81952428817749 3.2271857261657715 104.14584350585938
+44.644107818603516 14.648322105407715 8.014714241027832 3.352524757385254 103.4367446899414
+44.81422424316406 14.88651180267334 8.192360877990723 3.4678783416748047 102.79521179199219
+44.61906433105469 15.106553077697754 8.356470108032227 3.574901819229126 102.21159362792969
+44.398521423339844 15.310723304748535 8.50926685333252 3.6753571033477783 101.67486572265625
+44.681617736816406 15.501893997192383 8.652457237243652 3.7698445320129395 101.17818450927734
+44.25287628173828 15.68527889251709 8.788555145263672 3.8603785037994385 100.71075439453125
+44.36744689941406 15.86087417602539 8.9188232421875 3.947371482849121 100.26752471923828
+43.865875244140625 16.03199005126953 9.045238494873047 4.031771659851074 99.8448715209961
+43.819496154785156 16.19630241394043 9.167619705200195 4.113700866699219 99.44085693359375
+43.722869873046875 16.355485916137695 9.285438537597656 4.192569255828857 99.05418395996094
+43.653221130371094 16.508655548095703 9.399003982543945 4.26910924911499 98.68529510498047
+43.69430160522461 16.65662956237793 9.508933067321777 4.343038082122803 98.33161163330078
+43.627716064453125 16.79853057861328 9.61529541015625 4.414838790893555 97.9936294555664
+43.865360260009766 16.936098098754883 9.718077659606934 4.4843268394470215 97.66673278808594
+43.6976203918457 17.06801986694336 9.818470001220703 4.55241060256958 97.35028076171875
+43.59806823730469 17.19707679748535 9.916793823242188 4.6193623542785645 97.04191589355469
+43.52185821533203 17.322473526000977 10.013404846191406 4.685458183288574 96.74054718017578
+43.41917419433594 17.44563102722168 10.108713150024414 4.7507243156433105 96.44430541992188
+43.277366638183594 17.567068099975586 10.203481674194336 4.815809726715088 96.15127563476562
+43.01314163208008 17.689531326293945 10.297895431518555 4.880768775939941 95.86154174804688
+43.13426208496094 17.811311721801758 10.392668724060059 4.946417331695557 95.57191467285156
+43.431182861328125 17.93536376953125 10.488265991210938 5.012598514556885 95.2811279296875
+43.33560562133789 18.0587158203125 10.585355758666992 5.079974174499512 94.98808288574219
+43.18574142456055 18.18570899963379 10.684093475341797 5.148465633392334 94.69413757324219
+42.7459602355957 18.315410614013672 10.783915519714355 5.2178826332092285 94.3968276977539
+43.096519470214844 18.44548988342285 10.884159088134766 5.287846565246582 94.1004867553711
+42.6866455078125 18.5771427154541 10.985136985778809 5.358206748962402 93.80497741699219
+42.976036071777344 18.707651138305664 11.086485862731934 5.4292097091674805 93.50933837890625
+43.11697006225586 18.83973503112793 11.187674522399902 5.50021505355835 93.21643829345703
+42.67808532714844 18.97121810913086 11.289249420166016 5.571541786193848 92.92369079589844
+42.63383483886719 19.101526260375977 11.390812873840332 5.64274263381958 92.63370513916016
+42.93737030029297 19.23207664489746 11.491907119750977 5.714075565338135 92.34608459472656
+42.754302978515625 19.362140655517578 11.593321800231934 5.785522937774658 92.0596923828125
+42.79202651977539 19.489810943603516 11.694515228271484 5.857180595397949 91.77481842041016
+42.49104309082031 19.61664390563965 11.795486450195312 5.928684234619141 91.49217987060547
+42.830909729003906 19.74285125732422 11.896051406860352 6.000099182128906 91.21210479736328
+42.159950256347656 19.86749267578125 11.996854782104492 6.071923732757568 90.93277740478516
+42.749107360839844 19.990596771240234 12.096658706665039 6.1431379318237305 90.65858459472656
+42.336021423339844 20.11042594909668 12.196396827697754 6.214558124542236 90.38433837890625
+42.3348388671875 20.23137855529785 12.295783042907715 6.285806655883789 90.1128158569336
+42.150115966796875 20.348371505737305 12.394766807556152 6.356782913208008 89.84398651123047
+42.457061767578125 20.466428756713867 12.493119239807129 6.427802562713623 89.57843017578125
+42.125885009765625 20.579002380371094 12.591466903686523 6.498627185821533 89.31507873535156
+42.12937927246094 20.68899917602539 12.688401222229004 6.5691328048706055 89.05392456054688
+42.31658172607422 20.797204971313477 12.785552978515625 6.638766288757324 88.79661560058594
+41.86513137817383 20.89765739440918 12.881269454956055 6.709185600280762 88.54255676269531
+42.25408935546875 20.99154281616211 12.975740432739258 6.778041362762451 88.29133605957031
+41.85453796386719 21.07952880859375 13.068041801452637 6.846499443054199 88.04467010498047
+41.937679290771484 21.14443016052246 13.159239768981934 6.913124084472656 87.80535125732422
+41.81352615356445 21.196611404418945 13.244405746459961 6.976990222930908 87.57551574707031
+41.986900329589844 21.212711334228516 13.325448989868164 7.038490295410156 87.35929870605469
+41.287288665771484 21.178430557250977 13.396129608154297 7.094322204589844 87.15631103515625
+41.66656494140625 21.05164337158203 13.451090812683105 7.139124870300293 86.99600219726562
+41.6973762512207 20.79783821105957 13.484936714172363 7.170678615570068 86.87716674804688
+41.375267028808594 20.345922470092773 13.481538772583008 7.1853461265563965 86.81968688964844
+41.145530700683594 19.490928649902344 13.429767608642578 7.164709091186523 86.86746215820312
+40.98661804199219 17.97794532775879 13.28471851348877 7.102464199066162 87.0831069946289
+40.33439636230469 15.646788597106934 13.01026439666748 6.986405372619629 87.57977294921875
+39.76673126220703 12.255111694335938 12.600862503051758 6.830567836761475 88.2002182006836
+38.79609680175781 8.560340881347656 12.372932434082031 6.853498458862305 88.7684326171875
+38.32127380371094 5.709953308105469 12.79025936126709 7.192134380340576 88.40695190429688
+37.944427490234375 4.126208305358887 13.545401573181152 7.540185928344727 87.39126586914062
+36.917945861816406 3.0872910022735596 13.964518547058105 7.586120128631592 86.50907897949219
+36.96352767944336 1.7151166200637817 14.213419914245605 7.72288703918457 85.7274169921875
+36.2911262512207 1.2215452194213867 14.469675064086914 7.941455364227295 84.7396240234375
+36.07147979736328 1.0918352603912354 14.61170482635498 8.061705589294434 83.94670104980469
+35.376895904541016 0.7247174382209778 14.524091720581055 8.015606880187988 83.68299865722656
+35.48261260986328 0.3255567252635956 14.699923515319824 8.095706939697266 83.37874603271484
+35.48717498779297 0.08286981284618378 15.112645149230957 8.347959518432617 82.78411102294922
+35.17692565917969 0.05676012113690376 15.442191123962402 8.584199905395508 82.2193374633789
+35.239925384521484 0.17837312817573547 15.664982795715332 8.751606941223145 81.77082061767578
+35.63679504394531 0.32611167430877686 15.856014251708984 8.893160820007324 81.35819244384766
+35.1003532409668 0.41278931498527527 16.032421112060547 9.029252052307129 80.9532241821289
+35.02109909057617 0.40153250098228455 16.20699691772461 9.15910530090332 80.5573959350586
+35.041107177734375 0.3160436451435089 16.380008697509766 9.287110328674316 80.17388153076172
+34.9627571105957 0.20669147372245789 16.549646377563477 9.414285659790039 79.79834747314453
+34.42407989501953 0.10652432590723038 16.720090866088867 9.543060302734375 79.42079162597656
+34.85517120361328 0.03761119022965431 16.896657943725586 9.674033164978027 79.03876495361328
+34.085365295410156 0.008536041714251041 17.078664779663086 9.812566757202148 78.64576721191406
+34.62872314453125 0.01195006538182497 17.26502227783203 9.953902244567871 78.2486801147461
+33.91136169433594 0.03385341912508011 17.449636459350586 10.094586372375488 77.85381317138672
+34.22388458251953 0.061150193214416504 17.6286678314209 10.231950759887695 77.4742202758789
+34.37260437011719 0.07996329665184021 17.804039001464844 10.366182327270508 77.10540008544922
+33.894378662109375 0.0883290097117424 17.977975845336914 10.499126434326172 76.74427032470703
+34.02345275878906 0.08192689716815948 18.15253448486328 10.633553504943848 76.3824462890625
+33.71147918701172 0.06587406992912292 18.331321716308594 10.770267486572266 76.01707458496094
+33.85850524902344 0.0444037988781929 18.513572692871094 10.90976619720459 75.64701843261719
+33.54720687866211 0.023779654875397682 18.69814682006836 11.05124568939209 75.27527618408203
+33.51865768432617 0.00937062967568636 18.882957458496094 11.193376541137695 74.90524291992188
+33.280555725097656 0.0026857457123696804 19.067106246948242 11.334820747375488 74.54039001464844
+33.239864349365234 0.0036806021817028522 19.24959945678711 11.475122451782227 74.18112182617188
+33.31248092651367 0.009918763302266598 19.43085479736328 11.614605903625488 73.82701110839844
+33.17473220825195 0.018863555043935776 19.610984802246094 11.754105567932129 73.47702026367188
+33.29364776611328 0.026287322863936424 19.791412353515625 11.893133163452148 73.12947845458984
+33.03371810913086 0.03043908253312111 19.9722957611084 12.03377628326416 72.78225708007812
+33.19105529785156 0.030499860644340515 20.153955459594727 12.174575805664062 72.43688201904297
+32.8181037902832 0.026386110112071037 20.336206436157227 12.316449165344238 72.0916976928711
+32.963043212890625 0.01965807005763054 20.51811981201172 12.458364486694336 71.7500991821289
+32.712074279785156 0.012637785635888577 20.700300216674805 12.600712776184082 71.41024017333984
+32.986488342285156 0.00656893802806735 20.882287979125977 12.742627143859863 71.07305908203125
+32.58635330200195 0.0028721981216222048 21.064359664916992 12.885285377502441 70.7376937866211
+32.94102478027344 0.0016151332529261708 21.24653434753418 13.028000831604004 70.40452575683594
+32.470359802246094 0.0026956661604344845 21.429576873779297 13.171565055847168 70.07215118408203
+32.69208908081055 0.004916773177683353 21.612648010253906 13.31533145904541 69.7420883178711
+32.39582061767578 0.007171346805989742 21.796117782592773 13.459391593933105 69.41326904296875
+32.412872314453125 0.008692401461303234 21.979337692260742 13.603547096252441 69.08718872070312
+32.1214714050293 0.00904088094830513 22.162641525268555 13.747801780700684 68.7632064819336
+32.30231475830078 0.007936190813779831 22.345861434936523 13.892118453979492 68.44144439697266
+31.935733795166016 0.006228095851838589 22.528766632080078 14.036421775817871 68.12208557128906
+31.88172721862793 0.004133602604269981 22.711204528808594 14.18027114868164 67.80609893798828
+31.844497680664062 0.002417945535853505 22.893409729003906 14.324383735656738 67.49248504638672
+31.60824966430664 0.0015473236562684178 23.07514762878418 14.46813678741455 67.18156433105469
+31.878131866455078 0.0014660200104117393 23.256744384765625 14.611873626708984 66.8725814819336
+31.405391693115234 0.0020738362800329924 23.438190460205078 14.755956649780273 66.56602478027344
+31.73788070678711 0.0031235539354383945 23.61880111694336 14.899133682250977 66.26265716552734
+31.564834594726562 0.004135055933147669 23.798917770385742 15.04233455657959 65.96182250976562
+31.513378143310547 0.004640565253794193 23.979106903076172 15.185604095458984 65.66273498535156
+31.791465759277344 0.004678285680711269 24.15906524658203 15.328733444213867 65.3659896850586
+31.439090728759766 0.003959103487432003 24.339319229125977 15.472508430480957 65.07039642333984
+31.632381439208984 0.0033185556530952454 24.519060134887695 15.615784645080566 64.77737426757812
+31.565383911132812 0.002460731193423271 24.698564529418945 15.759139060974121 64.48613739013672
+31.294933319091797 0.00179059780202806 24.878028869628906 15.902580261230469 64.1970443725586
+31.16028594970703 0.0013411246472969651 25.057239532470703 16.045949935913086 63.91023254394531
+31.508045196533203 0.0013276215177029371 25.235685348510742 16.188928604125977 63.62577438354492
+31.120946884155273 0.0015024521853774786 25.41451072692871 16.33216667175293 63.342620849609375
+31.288448333740234 0.0018313348991796374 25.592836380004883 16.475122451782227 63.06159210205078
+31.21213150024414 0.001976747764274478 25.771381378173828 16.618453979492188 62.78228759765625
+30.932899475097656 0.0020824417006224394 25.949636459350586 16.76175880432129 62.504600524902344
+31.092870712280273 0.0020055468194186687 26.127275466918945 16.9046630859375 62.22951889038086
+30.955921173095703 0.0017118442337960005 26.304893493652344 17.047531127929688 61.95620346069336
+30.693370819091797 0.001490370836108923 26.481985092163086 17.190227508544922 61.68517303466797
+31.116661071777344 0.0012909423094242811 26.658267974853516 17.332244873046875 61.41692352294922
+30.98542022705078 0.0012351201148703694 26.834550857543945 17.47446632385254 61.150081634521484
+30.707416534423828 0.0012646751711145043 27.01055335998535 17.616626739501953 60.88502502441406
+30.72037124633789 0.001458437298424542 27.185619354248047 17.757789611816406 60.62295913696289
+30.545612335205078 0.0015852085780352354 27.359872817993164 17.89888572692871 60.363277435302734
+30.674283981323242 0.0017100338591262698 27.533533096313477 18.039365768432617 60.10617446899414
+30.549484252929688 0.001683989423327148 27.706296920776367 18.17936897277832 59.85129928588867
+30.748098373413086 0.0015857040416449308 27.879018783569336 18.319311141967773 59.59825897216797
+30.63336181640625 0.001511172391474247 28.051586151123047 18.459117889404297 59.3466682434082
+30.63526153564453 0.0013833813136443496 28.22370147705078 18.598737716674805 59.097015380859375
+30.412948608398438 0.0012507940409705043 28.3955020904541 18.73834228515625 58.848846435546875
+30.275821685791016 0.0011979987611994147 28.5670108795166 18.87779426574707 58.60262680053711
+30.47374153137207 0.0011826121481135488 28.7379150390625 19.016605377197266 58.35853576660156
+30.162376403808594 0.001212466275319457 28.908878326416016 19.155725479125977 58.11552429199219
+30.25480079650879 0.0011744516668841243 29.078868865966797 19.294124603271484 57.87514877319336
+30.502349853515625 0.0012205285020172596 29.248111724853516 19.43207359313965 57.636749267578125
+30.492717742919922 0.0012172253336757421 29.417865753173828 19.570528030395508 57.39912414550781
+30.37479019165039 0.0011765758972615004 29.58758544921875 19.70891571044922 57.16248321533203
+30.116981506347656 0.0011607077904045582 29.757394790649414 19.847511291503906 56.92728805541992
+30.50902557373047 0.0011379765346646309 29.926559448242188 19.98577880859375 56.693695068359375
+30.001434326171875 0.0011485731229186058 30.095766067504883 20.124135971069336 56.461456298828125
+29.95107650756836 0.0011245683999732137 30.263734817504883 20.261667251586914 56.23173141479492
+30.159881591796875 0.0010896407766267657 30.431089401245117 20.398509979248047 56.00435256958008
+29.959787368774414 0.0011480411048978567 30.5975284576416 20.53488540649414 55.778846740722656
+29.959007263183594 0.0011873360490426421 30.763608932495117 20.670814514160156 55.55533981323242
+29.80963897705078 0.0011442848481237888 30.92879867553711 20.806440353393555 55.333740234375
+29.902612686157227 0.0011268460657447577 31.0925235748291 20.94058609008789 55.115257263183594
+29.79055404663086 0.001121485373005271 31.255573272705078 21.074420928955078 54.898651123046875
+29.851119995117188 0.0010629232274368405 31.416915893554688 21.207061767578125 54.685325622558594
+29.725482940673828 0.001054457388818264 31.57727813720703 21.33881187438965 54.474151611328125
+29.965070724487305 0.0010525118559598923 31.737154006958008 21.470245361328125 54.264686584472656
+29.615781784057617 0.0010211587650701404 31.896963119506836 21.60175895690918 54.056095123291016
+29.830026626586914 0.0010319312568753958 32.055118560791016 21.7319278717041 53.85063552856445
+29.67072296142578 0.0010501735378056765 32.212928771972656 21.861953735351562 53.646400451660156
+29.5562801361084 0.0010397899895906448 32.37065124511719 21.991931915283203 53.443416595458984
+29.60136604309082 0.0010162133257836103 32.52687454223633 22.12078094482422 53.24301528930664
+29.439311981201172 0.0009753975318744779 32.6822509765625 22.24897003173828 53.04462814331055
+29.589420318603516 0.0009602526552043855 32.83667755126953 22.376588821411133 52.84818649291992
+29.268714904785156 0.000961730198469013 32.99033737182617 22.503093719482422 52.65382385253906
+29.41641616821289 0.0009786111768335104 33.14228820800781 22.62871742248535 52.46226119995117
+29.582231521606445 0.0009436884429305792 33.2933349609375 22.753681182861328 52.272377014160156
+29.551715850830078 0.0009577355813235044 33.444068908691406 22.878232955932617 52.08403015136719
+29.514461517333984 0.0009473756072111428 33.59403610229492 23.002323150634766 51.89717102050781
+29.478620529174805 0.0009534275741316378 33.74375915527344 23.126203536987305 51.711463928222656
+29.266042709350586 0.0009708592551760375 33.892547607421875 23.249523162841797 51.52750015258789
+29.20224380493164 0.0009241715888492763 34.04074478149414 23.372196197509766 51.34538269042969
+29.249187469482422 0.0008836665656417608 34.187591552734375 23.494037628173828 51.16524887084961
+29.4305362701416 0.0009282731334678829 34.33364486694336 23.614931106567383 50.986934661865234
+29.332382202148438 0.0009175487793982029 34.47880554199219 23.735349655151367 50.810359954833984
+29.311077117919922 0.0008957480895332992 34.62284469604492 23.85497283935547 50.63581466674805
+29.149982452392578 0.0008662471082061529 34.76615905761719 23.974056243896484 50.46276092529297
+29.136573791503906 0.0008705711807124317 34.908409118652344 24.0921573638916 50.291656494140625
+29.329143524169922 0.0008594110258854926 35.04959487915039 24.209257125854492 50.12259292602539
+29.363718032836914 0.0008835407206788659 35.19102478027344 24.326932907104492 49.953819274902344
+29.142024993896484 0.0008709598914720118 35.3324089050293 24.44451141357422 49.78553771972656
+29.09040641784668 0.0008239200687967241 35.47272872924805 24.56134796142578 49.619178771972656
+29.083938598632812 0.0008635807316750288 35.61294174194336 24.67792320251465 49.45380783081055
+29.130765914916992 0.0008772878791205585 35.753177642822266 24.794771194458008 49.2886962890625
+29.123214721679688 0.0008382525411434472 35.89255905151367 24.910850524902344 49.12544250488281
+29.20360565185547 0.0008664558990858495 36.0309944152832 25.026113510131836 48.96372604370117
+28.932403564453125 0.0008089658222161233 36.16846466064453 25.140771865844727 48.803714752197266
+29.235015869140625 0.0008035087957978249 36.304534912109375 25.254234313964844 48.64595413208008
+29.193073272705078 0.0008018799708224833 36.439476013183594 25.36675453186035 48.489959716796875
+28.971817016601562 0.0007773200632072985 36.57421875 25.479267120361328 48.33475875854492
+28.94454574584961 0.0008015862549655139 36.70803451538086 25.591033935546875 48.18117904663086
+28.892614364624023 0.0007752617821097374 36.840858459472656 25.701826095581055 48.029239654541016
+29.092647552490234 0.0007909027044661343 36.97315979003906 25.812442779541016 47.87839126586914
+29.167362213134766 0.0007597627700306475 37.10464096069336 25.92221450805664 47.72909164428711
+29.108333587646484 0.0007425089133903384 37.235660552978516 26.031570434570312 47.580841064453125
+29.028148651123047 0.0007744873873889446 37.36540222167969 26.14032745361328 47.434120178222656
+28.958667755126953 0.0007677128305658698 37.49473190307617 26.248445510864258 47.28873062133789
+28.998279571533203 0.0007209301693364978 37.62299346923828 26.35577964782715 47.14488983154297
+28.923160552978516 0.0007370186503976583 37.75156784057617 26.463396072387695 47.001224517822266
+29.064014434814453 0.0007247604080475867 37.87917709350586 26.570371627807617 46.85906219482422
+28.973533630371094 0.0007414333522319794 38.00645065307617 26.67698097229004 46.71754837036133
+29.096538543701172 0.0007436627056449652 38.13222885131836 26.782405853271484 46.578407287597656
+28.929466247558594 0.0007531988085247576 38.25749969482422 26.887248992919922 46.440242767333984
+28.983644485473633 0.0007160651148296893 38.38121032714844 26.991077423095703 46.3039665222168
+29.006500244140625 0.0007433315040543675 38.50391387939453 27.094032287597656 46.169307708740234
+28.908424377441406 0.0007056677713990211 38.62665939331055 27.19696044921875 46.03532791137695
+28.82782745361328 0.0006939780432730913 38.74825668334961 27.299007415771484 45.90266418457031
+28.906604766845703 0.0006876896368339658 38.86878967285156 27.40023422241211 45.771549224853516
+28.730758666992188 0.0006801909767091274 38.9893798828125 27.501466751098633 45.64095687866211
+28.82330322265625 0.0006492415559478104 39.108638763427734 27.60159683227539 45.51213836669922
+28.878252029418945 0.0006888331845402718 39.22636795043945 27.700742721557617 45.3851318359375
+28.868412017822266 0.0006464642938226461 39.3428840637207 27.798572540283203 45.259986877441406
+28.707761764526367 0.000665456464048475 39.457740783691406 27.895042419433594 45.13692092895508
+28.686912536621094 0.0006454915856011212 39.572296142578125 27.99158477783203 45.01451110839844
+28.684879302978516 0.0006353760254569352 39.685951232910156 28.087060928344727 44.89350891113281
+28.739269256591797 0.0006468240753747523 39.79867172241211 28.181913375854492 44.773651123046875
+28.561614990234375 0.0006328984163701534 39.91123962402344 28.276498794555664 44.654541015625
+28.75951385498047 0.0006148640532046556 40.02138900756836 28.369279861450195 44.538047790527344
+28.847688674926758 0.0006292716134339571 40.12962341308594 28.46031951904297 44.42408752441406
+28.649333953857422 0.0006137650925666094 40.23809051513672 28.551851272583008 44.30998992919922
+28.741592407226562 0.0006178381736390293 40.34550857543945 28.64224624633789 44.197418212890625
+28.70974349975586 0.0006116210133768618 40.45212173461914 28.7321720123291 44.085933685302734
+28.924192428588867 0.0006113676354289055 40.55678939819336 28.82021141052246 43.97673797607422
+28.59880256652832 0.0006179305491968989 40.66048812866211 28.907875061035156 43.86873245239258
+28.649932861328125 0.0005968677578493953 40.76370620727539 28.99480438232422 43.76168441772461
+28.777069091796875 0.0005764105590060353 40.86499786376953 29.080341339111328 43.65681838989258
+28.66537094116211 0.0005972077487967908 40.9649543762207 29.164566040039062 43.55355453491211
+28.536022186279297 0.0005784269887953997 41.063873291015625 29.247976303100586 43.45183181762695
+28.723634719848633 0.0005952474311925471 41.162349700927734 29.33132553100586 43.35039520263672
+28.66385269165039 0.0005518671241588891 41.259342193603516 29.413036346435547 43.25107955932617
+28.61137580871582 0.0005551231442950666 41.35530090332031 29.494165420532227 43.15269088745117
+28.74795913696289 0.0005644516786560416 41.45063781738281 29.574783325195312 43.05546951293945
+28.664234161376953 0.0005674618878401816 41.54621887207031 29.655317306518555 42.9582633972168
+28.51941680908203 0.0005497397505678236 41.63960266113281 29.73438262939453 42.86326599121094
+28.68932342529297 0.0005598982679657638 41.73111343383789 29.811473846435547 42.77069091796875
+28.71405029296875 0.0005376238841563463 41.820091247558594 29.88689613342285 42.68050003051758
+28.79716682434082 0.0005168390343897045 41.90836715698242 29.961563110351562 42.59117889404297
+28.764022827148438 0.0005310515407472849 41.99540710449219 30.034976959228516 42.503761291503906
+28.55709457397461 0.0005378834321163595 42.08255386352539 30.10877227783203 42.41622543334961
+28.718481063842773 0.0005239526508376002 42.16912841796875 30.181987762451172 42.32923126220703
+28.613672256469727 0.0005147290648892522 42.25476837158203 30.254602432250977 42.2432861328125
+28.536251068115234 0.0005187786882743239 42.33904266357422 30.325929641723633 42.15915298461914
+28.604318618774414 0.0004985079285688698 42.423458099365234 30.39729118347168 42.07495880126953
+28.426992416381836 0.0004904060042463243 42.508277893066406 30.469066619873047 41.990631103515625
+28.60281753540039 0.0005017347866669297 42.59010314941406 30.53841209411621 41.90924835205078
+28.588973999023438 0.0005177854327484965 42.671443939208984 30.607248306274414 41.82851791381836
+28.303844451904297 0.0004886785754933953 42.752227783203125 30.67563819885254 41.748695373535156
+28.618520736694336 0.00048090159543789923 42.83083724975586 30.742429733276367 41.670936584472656
+28.643815994262695 0.0004802040057256818 42.90910339355469 30.80866813659668 41.59376907348633
+28.75080680847168 0.0004807403893209994 42.987361907958984 30.874935150146484 41.516841888427734
+28.618919372558594 0.00048525736201554537 43.06496810913086 30.940576553344727 41.4405517578125
+28.439788818359375 0.0004838869790546596 43.14154052734375 31.005708694458008 41.36534118652344
+28.512432098388672 0.00047223191359080374 43.215789794921875 31.068431854248047 41.29269027709961
+28.73577117919922 0.00047865789383649826 43.28845977783203 31.130178451538086 41.221885681152344
+28.63562774658203 0.00047001364873722196 43.359066009521484 31.190290451049805 41.15259552001953
+28.475868225097656 0.00045817182399332523 43.43000411987305 31.250320434570312 41.0833854675293
+28.49782943725586 0.00046271493192762136 43.50001907348633 31.30975914001465 41.015228271484375
+28.62286376953125 0.0004390076792333275 43.56925582885742 31.368392944335938 40.948184967041016
+28.514118194580078 0.0004585244751069695 43.64011001586914 31.42816734313965 40.879661560058594
+28.63272476196289 0.00046446354826912284 43.7097053527832 31.48774528503418 40.81216049194336
+28.547142028808594 0.0004443485813681036 43.779991149902344 31.54689598083496 40.744319915771484
+28.409915924072266 0.00043041305616497993 43.84819030761719 31.605297088623047 40.678382873535156
+28.596485137939453 0.00044841860653832555 43.91385269165039 31.66082000732422 40.615116119384766
+28.542211532592773 0.0004197142261546105 43.97854232788086 31.715843200683594 40.55292892456055
+28.573589324951172 0.00044204722507856786 44.04113006591797 31.768909454345703 40.49248504638672
+28.544517517089844 0.0004271713551133871 44.1031608581543 31.821922302246094 40.43303298950195
+28.621854782104492 0.00043656217167153955 44.16427993774414 31.873746871948242 40.374446868896484
+28.6638126373291 0.0004312459204811603 44.2255859375 31.925342559814453 40.316043853759766
+28.450164794921875 0.000420932105043903 44.28573989868164 31.976917266845703 40.2583122253418
+28.669795989990234 0.00042816210770979524 44.345375061035156 32.02753829956055 40.20157241821289
+28.400583267211914 0.00040624276152811944 44.40559005737305 32.078914642333984 40.14407730102539
+28.51016616821289 0.0004052302392665297 44.46573257446289 32.12947463989258 40.08686447143555
+28.675472259521484 0.00041150517063215375 44.52363586425781 32.17888641357422 40.03170394897461
+28.56253433227539 0.00040591767174191773 44.58259963989258 32.2290153503418 39.975948333740234
+28.434600830078125 0.000403431651648134 44.64145278930664 32.27926254272461 39.92008590698242
+28.50853729248047 0.0003969169920310378 44.70026779174805 32.32884979248047 39.86470413208008
+28.448169708251953 0.0003991306002717465 44.758827209472656 32.37892150878906 39.80894088745117
+28.523544311523438 0.0003975876607000828 44.815242767333984 32.42695999145508 39.75591278076172
+28.641172409057617 0.0003983053902629763 44.871150970458984 32.47441101074219 39.70289611816406
+28.533416748046875 0.0003852443187497556 44.92626190185547 32.521217346191406 39.651344299316406
+28.556058883666992 0.00038308845250867307 44.979896545410156 32.56724166870117 39.60044860839844
+28.519208908081055 0.0004010615812148899 45.03215026855469 32.61162185668945 39.55179214477539
+28.420068740844727 0.00037689629243686795 45.08323669433594 32.65534973144531 39.50322341918945
+28.420312881469727 0.00038987529114820063 45.13572311401367 32.69918441772461 39.454463958740234
+28.477191925048828 0.00038605151348747313 45.186370849609375 32.74252700805664 39.40705490112305
+28.4501953125 0.0003865202597808093 45.23629379272461 32.785438537597656 39.36018753051758
+28.51528549194336 0.00037750351475551724 45.28706741333008 32.82859420776367 39.3125
+28.376962661743164 0.0003747885348275304 45.33750534057617 32.871742248535156 39.26555633544922
+28.43338394165039 0.00036097457632422447 45.3854866027832 32.91228103637695 39.220703125
+28.29214859008789 0.00036755617475137115 45.43032455444336 32.95004653930664 39.17900466918945
+28.4688720703125 0.000356670148903504 45.47429275512695 32.98843002319336 39.137508392333984
+28.474777221679688 0.0003604388621170074 45.51810073852539 33.02526092529297 39.09675216674805
+28.482196807861328 0.00035427292459644377 45.56031799316406 33.061710357666016 39.05723571777344
+28.408679962158203 0.0003458287683315575 45.601741790771484 33.097007751464844 39.01863479614258
+28.350276947021484 0.00035400650813244283 45.64236831665039 33.13077926635742 38.98117446899414
+28.66348648071289 0.00034807322663255036 45.68247604370117 33.16499710083008 38.943607330322266
+28.461214065551758 0.00035075476625934243 45.72233963012695 33.20002365112305 38.906341552734375
+28.607383728027344 0.0003370194463059306 45.76249313354492 33.23353576660156 38.869197845458984
+28.488933563232422 0.00034702717675827444 45.80184555053711 33.266868591308594 38.832923889160156
+28.554386138916016 0.0003328612365294248 45.83881378173828 33.29906463623047 38.79826736450195
+28.40814971923828 0.0003404106246307492 45.87451171875 33.32961654663086 38.765220642089844
+28.25954246520996 0.0003329061437398195 45.9099006652832 33.3598518371582 38.73208236694336
+28.490936279296875 0.00032925684354268014 45.94490432739258 33.38873291015625 38.69956588745117
+28.445213317871094 0.00033926236210390925 45.97968673706055 33.418643951416016 38.66763687133789
+28.32928466796875 0.00033446226734668016 46.013916015625 33.4487419128418 38.635128021240234
+28.499528884887695 0.0003436672850511968 46.049720764160156 33.47816848754883 38.60212326049805
+28.608379364013672 0.0003296883369330317 46.08338165283203 33.50736999511719 38.57111740112305
+28.550151824951172 0.00032545419526286423 46.11547088623047 33.53593444824219 38.54096984863281
+28.56536865234375 0.0003244446124881506 46.14750289916992 33.56296157836914 38.51141357421875
+28.35757827758789 0.0003244592226110399 46.18011474609375 33.590267181396484 38.48189163208008
+28.590011596679688 0.000324867112794891 46.21340560913086 33.61876678466797 38.45020294189453
+28.597118377685547 0.00031181619851849973 46.24359893798828 33.64467239379883 38.42224884033203
+28.41694450378418 0.00031218057847581804 46.27229309082031 33.66997146606445 38.39564514160156
+28.641176223754883 0.0003055623674299568 46.30002212524414 33.69255065917969 38.37034225463867
+28.420551300048828 0.0003230106667615473 46.323726654052734 33.712650299072266 38.347774505615234
+28.39084243774414 0.0003150381671730429 46.3442497253418 33.73167037963867 38.328224182128906
+28.404468536376953 0.00030749873258173466 46.36588668823242 33.749149322509766 38.308109283447266
+28.627540588378906 0.0003137972089461982 46.38725280761719 33.766944885253906 38.288543701171875
+28.434776306152344 0.0003092889382969588 46.40586471557617 33.78385925292969 38.26985168457031
+28.50838851928711 0.0003014611720573157 46.42359924316406 33.79828643798828 38.25272750854492
+28.3504581451416 0.0002947713655885309 46.44082260131836 33.81328201293945 38.23492431640625
+28.404315948486328 0.00030488980701193213 46.454673767089844 33.82495880126953 38.22248458862305
+28.461776733398438 0.00029863903182558715 46.466522216796875 33.83440399169922 38.21051788330078
+28.55781364440918 0.00030750437872484326 46.47730255126953 33.844783782958984 38.19892501831055
+28.392087936401367 0.0003048052894882858 46.48762893676758 33.85269546508789 38.18907928466797
+28.572731018066406 0.0002895324432756752 46.4947395324707 33.86023712158203 38.18124008178711
+28.58648681640625 0.0002902542764786631 46.49681854248047 33.86278533935547 38.17586898803711
+28.447660446166992 0.0002801929076667875 46.50159454345703 33.86719512939453 38.17045593261719
+28.552913665771484 0.0002806647098623216 46.50617980957031 33.86824417114258 38.16586685180664
+28.259523391723633 0.000287985778413713 46.50706100463867 33.86839294433594 38.162593841552734
+28.429901123046875 0.0002851515600923449 46.50300598144531 33.86765670776367 38.16332244873047
+28.587900161743164 0.000294311175821349 46.505062103271484 33.868492126464844 38.15962600708008
+28.50507354736328 0.0002714733127504587 46.50603103637695 33.87202072143555 38.15716552734375
+28.394044876098633 0.0002808408753480762 46.506011962890625 33.87083435058594 38.15249252319336
+28.514827728271484 0.0002800922957248986 46.494850158691406 33.864192962646484 38.158077239990234
+28.38652801513672 0.0002783247036859393 46.47908020019531 33.85045623779297 38.169898986816406
+28.324886322021484 0.0002595640835352242 46.45940399169922 33.83213424682617 38.18260192871094
+28.53148078918457 0.00027034315280616283 46.44475173950195 33.819515228271484 38.195003509521484
+28.21320343017578 0.0002624048211146146 46.40753936767578 33.79180145263672 38.22148513793945
+28.50330352783203 0.0002672462142072618 46.35752487182617 33.74827194213867 38.25981521606445
+28.400634765625 0.00026736134896054864 46.301536560058594 33.70325469970703 38.30238342285156
+28.35105323791504 0.0002631787792779505 46.22673034667969 33.642417907714844 38.35870361328125
+28.43183135986328 0.0002570388314779848 46.137603759765625 33.569908142089844 38.432533264160156
+28.4847412109375 0.00025069015100598335 46.024444580078125 33.47439956665039 38.521263122558594
+28.44074058532715 0.0002455110370647162 45.925594329833984 33.38779067993164 38.595943450927734
+28.480209350585938 0.00025412486866116524 45.82779312133789 33.300743103027344 38.67420196533203
+28.400245666503906 0.00025417626602575183 45.688438415527344 33.191650390625 38.78103256225586
+28.664348602294922 0.00025412370450794697 45.568233489990234 33.1007080078125 38.86764907836914
+28.312442779541016 0.000250027485890314 45.46977233886719 33.01408767700195 38.930641174316406
+28.422801971435547 0.00024805840803310275 45.38546371459961 32.955108642578125 38.97951126098633
+28.43191146850586 0.00024701154325157404 45.29672622680664 32.87892150878906 39.02630615234375
+28.369956970214844 0.0002510065969545394 45.091339111328125 32.723915100097656 39.15595626831055
+28.3637638092041 0.00025288635515607893 44.81169509887695 32.499359130859375 39.33035659790039
+28.35848617553711 0.00025187182473018765 44.413482666015625 32.19489288330078 39.580604553222656
+28.39171600341797 0.00027116978890262544 43.606040954589844 31.572986602783203 40.148738861083984
+28.193674087524414 0.0002928540052380413 42.402286529541016 30.656965255737305 40.901065826416016
+28.149242401123047 0.0003996325540356338 40.35368347167969 29.042800903320312 42.28657150268555
+27.85210418701172 0.0006054005934856832 37.588951110839844 26.804574966430664 44.103172302246094
+27.706741333007812 0.0011205063201487064 33.65530776977539 24.12294578552246 46.20450973510742
+26.81024169921875 0.0015996521105989814 31.610485076904297 23.385269165039062 45.492706298828125
+26.70167350769043 0.0014337936881929636 29.14063262939453 22.627099990844727 46.44264602661133
+26.4800968170166 0.001413014018908143 32.806236267089844 26.23922348022461 42.874351501464844
+26.629453659057617 0.010625436902046204 17.903522491455078 15.546953201293945 55.38240432739258
+26.92885971069336 0.12534089386463165 36.299190521240234 29.597841262817383 40.49801254272461
+26.219865798950195 0.2534066438674927 32.848121643066406 27.358726501464844 41.17094802856445
+26.216529846191406 0.2740502655506134 16.19249725341797 13.827692031860352 54.92939758300781
+25.87565040588379 0.2200668603181839 17.656909942626953 15.199116706848145 52.347625732421875
+25.120210647583008 0.11522010713815689 26.729331970214844 23.814199447631836 42.76237106323242
+24.769771575927734 0.02050969935953617 26.56079864501953 24.151269912719727 41.847557067871094
+23.560583114624023 0.02736527845263481 14.279704093933105 13.78139877319336 49.01264190673828
+24.324268341064453 0.10499728471040726 6.464270114898682 6.011507034301758 57.99433135986328
+22.83954429626465 0.06346417218446732 11.48595905303955 11.485594749450684 49.30959701538086
+23.287349700927734 0.050499483942985535 16.99857521057129 16.511207580566406 45.72174072265625
+22.73160171508789 0.1316273808479309 15.125231742858887 13.455604553222656 46.33866500854492
+22.21162986755371 0.20693831145763397 10.127638816833496 7.966124057769775 50.14822006225586
+22.362401962280273 0.22864004969596863 8.153997421264648 5.3282012939453125 52.622562408447266
+22.057527542114258 0.2604747712612152 8.802725791931152 5.592729091644287 51.69017028808594
+22.039669036865234 0.20711098611354828 11.040827751159668 7.593510627746582 49.87142562866211
+21.841405868530273 0.12826557457447052 13.673264503479004 10.227747917175293 46.718536376953125
+22.100505828857422 0.13413268327713013 15.202756881713867 11.75343132019043 45.77716827392578
+21.55773162841797 0.2271568775177002 14.48483943939209 11.903746604919434 45.01262283325195
+21.381990432739258 0.29127037525177 11.742265701293945 9.772477149963379 46.34600830078125
+21.360931396484375 0.25438541173934937 9.006109237670898 7.776650905609131 48.23799133300781
+21.156108856201172 0.12474274635314941 8.088827133178711 7.334085941314697 48.89118576049805
+20.932743072509766 0.020399289205670357 9.31131649017334 9.292236328125 46.657718658447266
+20.680063247680664 0.030867841094732285 10.929441452026367 11.948062896728516 44.222076416015625
+20.757701873779297 0.08665759861469269 10.609037399291992 12.026479721069336 43.874595642089844
+20.447370529174805 0.10136115550994873 9.128168106079102 11.046424865722656 44.62647247314453
+20.132484436035156 0.07161642611026764 7.282068729400635 9.715385437011719 45.61766815185547
+19.94980239868164 0.03781922534108162 7.081333160400391 11.100905418395996 43.91503143310547
+19.729373931884766 0.011013786308467388 8.49039077758789 15.13127326965332 41.05959701538086
+19.13616371154785 0.007630822714418173 7.556964874267578 16.323915481567383 39.772804260253906
+19.64910125732422 0.048012688755989075 3.9131672382354736 13.11367416381836 43.97441101074219
+19.323644638061523 0.042817093431949615 10.838120460510254 23.950546264648438 35.56352233886719
+18.769208908081055 0.05211646109819412 3.5717432498931885 15.429579734802246 41.021419525146484
+17.97724723815918 0.017967063933610916 4.842341423034668 18.99751853942871 37.62919616699219
+18.11341667175293 0.0035875029861927032 6.207339763641357 21.460880279541016 35.93375778198242
+17.78685760498047 0.004961112048476934 3.5371739864349365 18.119909286499023 37.86012268066406
+17.46732521057129 0.01642952486872673 3.7592973709106445 19.95846939086914 35.98874282836914
+17.625843048095703 0.02836533449590206 4.283655166625977 21.52991485595703 35.53505325317383
+17.313579559326172 0.024599498137831688 4.267200946807861 21.8309268951416 34.5312385559082
+17.305105209350586 0.009661465883255005 3.0793297290802 19.857627868652344 35.95785903930664
+17.067153930664062 0.004614970181137323 3.8725061416625977 22.90532112121582 33.91291809082031
+16.79264259338379 0.003372375387698412 3.079127788543701 21.66033363342285 34.165565490722656
+16.56806182861328 0.004924009088426828 2.5691604614257812 20.872905731201172 33.98141860961914
+16.509557723999023 0.003453096840530634 3.3581624031066895 23.516637802124023 32.40495300292969
+16.24405288696289 0.003324346151202917 2.546912431716919 21.645254135131836 32.98391342163086
+16.07879066467285 0.006120431236922741 2.0008716583251953 21.15831756591797 33.0939826965332
+16.085216522216797 0.018515028059482574 2.984768867492676 22.988372802734375 31.614839553833008
+15.675837516784668 0.02509821578860283 2.376929998397827 21.815248489379883 31.785337448120117
+15.683835983276367 0.023094775155186653 1.732723593711853 20.000869750976562 33.12538528442383
+15.699033737182617 0.029066719114780426 3.1889097690582275 23.956514358520508 30.007347106933594
+15.394302368164062 0.012931277975440025 1.5790671110153198 18.880760192871094 32.29540252685547
+15.096162796020508 0.008310798555612564 2.0857675075531006 20.776512145996094 31.03923797607422
+15.179811477661133 0.0046249451115727425 2.407567024230957 22.95030403137207 29.78004264831543
+15.200380325317383 0.0031865688506513834 1.2206496000289917 17.479780197143555 32.294776916503906
+15.104679107666016 0.004516741260886192 3.496262311935425 26.43575096130371 27.609149932861328
+14.892333984375 0.0029385534580796957 0.9949859976768494 17.642946243286133 31.8115177154541
+14.72230339050293 0.007681686896830797 1.7335100173950195 21.09756851196289 29.274423599243164
+14.54010009765625 0.015663614496588707 2.2634880542755127 23.527053833007812 27.899188995361328
+14.475801467895508 0.009895652532577515 0.9273936748504639 16.427629470825195 30.964942932128906
+14.129654884338379 0.021215146407485008 1.850778341293335 22.65036964416504 27.638368606567383
+14.055046081542969 0.01620286889374256 1.3453779220581055 21.05855369567871 27.98395538330078
+13.923192977905273 0.0075626070611178875 0.9432594180107117 17.24907875061035 29.194272994995117
+14.101844787597656 0.00825201440602541 1.8228325843811035 22.98409080505371 27.09549903869629
+13.792023658752441 0.0038583334535360336 1.2923856973648071 20.688474655151367 27.37554931640625
+13.818933486938477 0.003185442415997386 0.6683447957038879 16.08180809020996 29.57540512084961
+13.89638900756836 0.0062012444250285625 2.575373411178589 25.17064666748047 24.969057083129883
+13.317001342773438 0.0036314958706498146 0.9488793015480042 17.67890167236328 27.551349639892578
+13.348176956176758 0.0038386150263249874 0.681195080280304 15.91984748840332 28.332210540771484
+13.504878044128418 0.013798634521663189 2.3820338249206543 25.12133026123047 24.148998260498047
+13.125104904174805 0.006969879847019911 0.8040804266929626 17.459915161132812 26.991106033325195
+13.08723258972168 0.005702491849660873 0.6383076906204224 15.861266136169434 27.761615753173828
+13.143928527832031 0.01432669349014759 1.9063133001327515 24.33791732788086 23.667068481445312
+12.748140335083008 0.006312782410532236 0.8977740406990051 19.102195739746094 25.292247772216797
+12.960736274719238 0.0029890916775912046 0.43934234976768494 13.296622276306152 28.015417098999023
+12.69442367553711 0.009924049489200115 1.8414568901062012 25.158145904541016 22.291654586791992
+12.493182182312012 0.006945708300918341 0.9520702958106995 20.074111938476562 23.972095489501953
+12.764747619628906 0.0033897145185619593 0.31835395097732544 12.264530181884766 28.51565170288086
+12.543745040893555 0.019229736179113388 1.4684017896652222 23.129459381103516 22.548696517944336
+12.342719078063965 0.023356415331363678 1.5217885971069336 23.230688095092773 22.238428115844727
+12.237143516540527 0.006886215880513191 0.4492354989051819 13.434186935424805 26.481945037841797
+11.857463836669922 0.0072595421224832535 0.6574143171310425 16.806245803833008 24.036590576171875
+12.129661560058594 0.007864419370889664 1.2648130655288696 22.821842193603516 21.4608211517334
+11.685218811035156 0.0042539117857813835 0.6162277460098267 16.649248123168945 23.47508430480957
+11.828296661376953 0.0054115718230605125 0.432476669549942 14.465930938720703 24.726463317871094
+11.732421875 0.005388631019741297 1.1135931015014648 21.898366928100586 20.936426162719727
+11.342023849487305 0.0063193244859576225 0.6988417506217957 17.948522567749023 22.060747146606445
+11.633755683898926 0.006426772102713585 0.33657196164131165 12.253266334533691 24.989166259765625
+11.539738655090332 0.02250359207391739 1.1809625625610352 21.49508285522461 20.359371185302734
+11.230278015136719 0.020042171701788902 0.7517441511154175 17.829145431518555 21.526531219482422
+11.269291877746582 0.009411735460162163 0.38176488876342773 12.2282133102417 24.11606216430664
+11.175819396972656 0.010771444998681545 0.7673386335372925 19.494401931762695 20.490367889404297
+10.744062423706055 0.0061853136867284775 0.7340256571769714 18.736698150634766 20.008012771606445
+10.940156936645508 0.004801602102816105 0.34611788392066956 12.529172897338867 23.051359176635742
+10.755083084106445 0.004542824812233448 0.6774742603302002 17.8089542388916 19.723129272460938
+10.691060066223145 0.0050443317741155624 0.6730852723121643 17.72230339050293 20.14796257019043
+10.605918884277344 0.004998223856091499 0.42043042182922363 14.446786880493164 21.15486717224121
+10.280235290527344 0.010013362392783165 0.5512047410011292 16.297517776489258 19.54430389404297
+10.527544021606445 0.011787664145231247 0.4667377471923828 14.648416519165039 20.64792251586914
+10.362113952636719 0.01338471844792366 0.5807163715362549 16.191329956054688 19.51226806640625
+10.259082794189453 0.009614572860300541 0.500143826007843 15.063800811767578 19.734779357910156
+10.179882049560547 0.005030939355492592 0.39742016792297363 13.141473770141602 20.476669311523438
+9.981155395507812 0.004638828337192535 0.794457733631134 18.784439086914062 17.513172149658203
+10.029723167419434 0.003650181693956256 0.3159061670303345 10.319122314453125 21.370759963989258
+10.067034721374512 0.0045571476221084595 0.8129144906997681 18.851797103881836 17.15652847290039
+9.809454917907715 0.004704788792878389 0.5482761859893799 13.528681755065918 19.33129119873047
+9.767518997192383 0.005518687888979912 0.37715160846710205 12.343184471130371 19.765186309814453
+9.821378707885742 0.009945937432348728 0.8764588236808777 18.714569091796875 16.641340255737305
+9.798145294189453 0.004011243116110563 0.2792273461818695 8.514620780944824 21.283809661865234
+9.59415054321289 0.010443031787872314 0.8739138841629028 19.197765350341797 15.668316841125488
+9.418681144714355 0.0054527707397937775 0.3708093762397766 11.866429328918457 19.013357162475586
+9.178808212280273 0.004409354645758867 0.37408798933029175 11.3314847946167 18.454246520996094
+9.18424129486084 0.005271884147077799 0.662591278553009 17.69823455810547 15.508811950683594
+9.477020263671875 0.0024883595760911703 0.2741685211658478 7.359949588775635 20.928544998168945
+9.372575759887695 0.010131471790373325 0.9113355875015259 20.125579833984375 14.46837329864502
+9.036276817321777 0.008309534750878811 0.32945162057876587 11.296199798583984 18.018918991088867
+8.937849998474121 0.010266491211950779 0.32766589522361755 10.828840255737305 17.956462860107422
+9.077298164367676 0.014859811402857304 0.6996814012527466 18.815256118774414 14.413175582885742
+9.084192276000977 0.004593776073306799 0.23705452680587769 9.30163288116455 19.165727615356445
+8.730537414550781 0.005749554373323917 0.4862631857395172 14.309078216552734 15.44126033782959
+8.647296905517578 0.004107852932065725 0.4233301281929016 13.924498558044434 15.33063793182373
+8.7614107131958 0.002874293364584446 0.20831161737442017 8.456997871398926 18.714704513549805
+8.720575332641602 0.006356699857860804 0.7133710980415344 18.094343185424805 13.51932430267334
+8.401287078857422 0.0054705520160496235 0.30983221530914307 10.456298828125 16.661052703857422
+8.386844635009766 0.007231353782117367 0.3390016555786133 10.523448944091797 16.32994842529297
+8.447556495666504 0.010334067977964878 0.5856922268867493 16.118886947631836 13.785906791687012
+8.380182266235352 0.004148589447140694 0.2336633801460266 8.529430389404297 17.66853141784668
+8.167287826538086 0.0056786322966217995 0.4831653833389282 14.226655006408691 13.989346504211426
+8.049259185791016 0.0038499077782034874 0.4257371723651886 11.440431594848633 14.858681678771973
+8.094804763793945 0.003499704645946622 0.3156496286392212 9.947546005249023 15.7068452835083
+7.876148223876953 0.0033995183184742928 0.4438004493713379 13.371910095214844 13.591190338134766
+7.962082862854004 0.0033415090292692184 0.3122199475765228 10.956095695495605 15.074238777160645
+7.904879570007324 0.0033399537205696106 0.37677425146102905 11.939742088317871 14.17686653137207
+7.739428520202637 0.0040823896415531635 0.3182576298713684 10.49742603302002 14.395804405212402
+7.711947917938232 0.005697077605873346 0.3462582230567932 10.410996437072754 14.208863258361816
+7.735284328460693 0.00748341204598546 0.34841129183769226 10.817641258239746 14.147290229797363
+7.620716094970703 0.00828245934098959 0.35605189204216003 11.50623893737793 13.263075828552246
+7.703177452087402 0.006094112992286682 0.233809232711792 9.385782241821289 14.757762908935547
+7.776765823364258 0.008140490390360355 0.3872133791446686 13.173210144042969 12.640761375427246
+7.400617599487305 0.005059806630015373 0.24968241155147552 8.290279388427734 14.535019874572754
+7.431169033050537 0.00635011401027441 0.40517565608024597 12.686567306518555 12.221831321716309
+7.246102809906006 0.0032952472101897 0.2800445854663849 8.970000267028809 13.673304557800293
+7.267775058746338 0.0029076586943119764 0.29326122999191284 10.184920310974121 13.237398147583008
+7.2004876136779785 0.0028860168531537056 0.3562837839126587 11.479683876037598 11.903385162353516
+7.382078170776367 0.0026346950326114893 0.20666249096393585 8.203243255615234 14.319108963012695
+7.114922046661377 0.0058182114735245705 0.5655872821807861 14.443612098693848 10.426458358764648
+7.73629093170166 0.003485910128802061 0.1331043392419815 5.28834867477417 17.042858123779297
+8.229644775390625 0.0213896706700325 1.3456169366836548 23.270858764648438 8.281781196594238
+8.279546737670898 0.007329731248319149 0.0886460617184639 2.913235664367676 19.88374137878418
+7.087226867675781 0.03930900618433952 0.497429221868515 10.913969993591309 12.30025577545166
+7.773131370544434 0.03079502284526825 1.3656668663024902 19.564960479736328 9.15528392791748
+7.443472385406494 0.006111220922321081 0.20445628464221954 5.455012321472168 16.42494773864746
+6.910019874572754 0.029857847839593887 0.25476670265197754 7.02294921875 14.048952102661133
+7.1545820236206055 0.04086465761065483 0.7855075597763062 16.163963317871094 9.5851411819458
+6.637293815612793 0.05382828414440155 0.2836848497390747 9.311964988708496 11.989830017089844
+7.048799514770508 0.034096114337444305 0.14718294143676758 5.768387794494629 14.928868293762207
+6.648942947387695 0.003582879202440381 0.30383947491645813 10.810253143310547 10.960819244384766
+6.813208103179932 0.026494881138205528 0.37401100993156433 13.16760540008545 10.00086498260498
+6.608423233032227 0.036704808473587036 0.15843349695205688 7.291318416595459 12.753477096557617
+6.882755279541016 0.03634819760918617 0.16584599018096924 8.258622169494629 12.737399101257324
+6.7360029220581055 0.02927853725850582 0.3401866555213928 13.898775100708008 8.991658210754395
+6.536423206329346 0.00677800877019763 0.1637628972530365 7.902960300445557 12.201643943786621
+6.645792007446289 0.0025842483155429363 0.16687127947807312 7.563205242156982 12.683738708496094
+6.557724952697754 0.003768830792978406 0.34861671924591064 12.464113235473633 9.323756217956543
+6.239143371582031 0.003398811211809516 0.1881965547800064 8.428974151611328 11.00717544555664
+6.331752777099609 0.003465001005679369 0.174311101436615 6.634279251098633 12.317744255065918
+6.161978721618652 0.005910922307521105 0.27863240242004395 9.565751075744629 10.031196594238281
+6.29011344909668 0.00429893983528018 0.3778620958328247 9.822149276733398 10.095605850219727
+6.269426345825195 0.0035508200526237488 0.20762716233730316 7.56121301651001 11.723174095153809
+6.195627212524414 0.004890464246273041 0.1991330087184906 7.738576412200928 11.364215850830078
+6.231379508972168 0.0036136899143457413 0.2795380651950836 10.30377197265625 9.746206283569336
+6.166055202484131 0.0029123772401362658 0.23352114856243134 8.97132682800293 10.517036437988281
+6.009759902954102 0.0028306294698268175 0.16028405725955963 6.936009407043457 11.106677055358887
+5.926337242126465 0.006144433282315731 0.20245040953159332 8.12649154663086 10.157185554504395
+5.787542343139648 0.009350733831524849 0.23889698088169098 9.16025161743164 9.153035163879395
+6.072867393493652 0.005932176951318979 0.15561072528362274 6.539491176605225 11.287739753723145
+5.903041839599609 0.006411904003471136 0.25631165504455566 9.708303451538086 9.056766510009766
+5.92368745803833 0.0032024108804762363 0.1863647699356079 7.5123796463012695 10.501344680786133
+5.749135971069336 0.0028978493064641953 0.24479036033153534 7.822214603424072 9.868867874145508
+5.714685916900635 0.003028944367542863 0.22194364666938782 8.390839576721191 9.262652397155762
+5.756513595581055 0.0025984791573137045 0.16848890483379364 6.116439342498779 10.72169017791748
+5.626316547393799 0.0034674210473895073 0.26866480708122253 8.612086296081543 9.011700630187988
+5.836090087890625 0.0030801326502114534 0.21941277384757996 7.982243537902832 9.767992973327637
+5.684250354766846 0.0027684608940035105 0.17002306878566742 7.258749961853027 9.979669570922852
+5.70443868637085 0.002821578411385417 0.2504521310329437 8.233560562133789 9.36325740814209
+5.591163635253906 0.002940700389444828 0.23854029178619385 8.245356559753418 9.030926704406738
+5.561385154724121 0.003446980845183134 0.1780899614095688 6.228780746459961 10.229394912719727
+5.549156188964844 0.0028748922049999237 0.27035999298095703 9.02712631225586 8.186339378356934
+5.506468772888184 0.0027998853474855423 0.16819578409194946 7.244990825653076 9.2378511428833
+5.341632843017578 0.002659953199326992 0.147178515791893 5.631044387817383 9.969197273254395
+5.461150169372559 0.00375934736803174 0.2518717348575592 8.763413429260254 8.12763786315918
+5.519473075866699 0.0034156832844018936 0.1906280666589737 7.378134250640869 8.988987922668457
+5.474452972412109 0.0026693197432905436 0.11829563230276108 5.176069736480713 10.544636726379395
+5.393915176391602 0.004584948532283306 0.2600984573364258 9.583146095275879 7.6008830070495605
+5.364865303039551 0.0031470186077058315 0.14294388890266418 6.294332504272461 9.750076293945312
+4.990211486816406 0.002789543243125081 0.1830960363149643 6.815423965454102 8.474647521972656
+5.321070671081543 0.002646754728630185 0.19052201509475708 7.3726911544799805 8.757439613342285
+5.293663024902344 0.003054677741602063 0.1739887297153473 7.224431037902832 8.754507064819336
+5.244317054748535 0.0034862083848565817 0.12583008408546448 5.7208638191223145 9.584334373474121
+5.154820442199707 0.0027266424149274826 0.20627917349338531 7.331341743469238 8.099011421203613
+5.1912760734558105 0.0026011373847723007 0.19532491266727448 6.476282119750977 8.722684860229492
+5.098459243774414 0.002569649601355195 0.168925479054451 6.81678581237793 8.47428035736084
+5.068840026855469 0.002372684655711055 0.1425289362668991 6.683045387268066 8.496500015258789
+5.006413459777832 0.00236439798027277 0.15863603353500366 6.020745754241943 8.834212303161621
+5.054623603820801 0.0030017741955816746 0.219291552901268 8.648866653442383 7.106564044952393
+5.283446788787842 0.0021899868734180927 0.09282400459051132 4.669597625732422 10.284239768981934
+5.000825881958008 0.005013875663280487 0.2463216781616211 8.8759765625 6.867989540100098
+5.042392730712891 0.00407671183347702 0.14774729311466217 5.791230201721191 9.213920593261719
+4.978821754455566 0.004061232320964336 0.18689954280853271 6.403209209442139 8.354424476623535
+5.095612525939941 0.003070669947192073 0.2196444272994995 7.863108158111572 7.529740333557129
+4.948460578918457 0.0029409355483949184 0.12332407385110855 4.831955432891846 9.344791412353516
+4.903352737426758 0.002701954683288932 0.2288776934146881 7.032074451446533 7.510995864868164
+4.942036151885986 0.0030455412343144417 0.1908901184797287 6.860835552215576 7.850350856781006
+4.794223308563232 0.0035568666644394398 0.12566398084163666 5.176643371582031 8.8623685836792
+4.797407627105713 0.0024862808641046286 0.2163713425397873 7.532882213592529 7.023167610168457
+4.853455543518066 0.002522163325920701 0.1455077826976776 5.698103427886963 8.302152633666992
+4.746079444885254 0.0028336821123957634 0.1267285794019699 5.926201820373535 7.99560022354126
+4.7381744384765625 0.0031805122271180153 0.1410725712776184 5.839328765869141 7.717927932739258
+4.674890518188477 0.002628915011882782 0.09051024913787842 5.295978546142578 8.129853248596191
+4.724562644958496 0.0027485957834869623 0.15084372460842133 6.348140716552734 7.239356994628906
+4.514912128448486 0.0022766077890992165 0.13645590841770172 5.888221740722656 7.375148773193359
+4.6840620040893555 0.002238156972452998 0.10156215727329254 4.576747894287109 8.592375755310059
+4.706272602081299 0.0035409301053732634 0.1818869560956955 6.709778785705566 6.870720386505127
+4.596179485321045 0.0033792222384363413 0.15386131405830383 5.3513994216918945 7.680607795715332
+4.738674640655518 0.00279362965375185 0.15058064460754395 5.660650253295898 8.057584762573242
+4.586110591888428 0.0025989250279963017 0.22149668633937836 7.487161159515381 6.448708534240723
+4.562119483947754 0.003347471822053194 0.08113744854927063 4.02432918548584 8.863924026489258
+4.617722034454346 0.002524956129491329 0.19338178634643555 7.284975051879883 6.329426288604736
+4.552440643310547 0.0030099786818027496 0.09698759019374847 5.571907043457031 7.630558967590332
+4.609381675720215 0.00252721575088799 0.08751463145017624 4.498779296875 8.458520889282227
+4.435260772705078 0.003852733876556158 0.2035946547985077 6.974410533905029 6.176839351654053
+4.573869228363037 0.003946606535464525 0.12675607204437256 5.168624401092529 8.144904136657715
+4.416800022125244 0.004265359602868557 0.14357642829418182 5.651460647583008 7.289285182952881
+4.503790855407715 0.003540175501257181 0.20795376598834991 7.616572380065918 6.1124491691589355
+4.656459331512451 0.0029196212999522686 0.06019797548651695 3.68886137008667 9.169205665588379
+4.442157745361328 0.002508327830582857 0.18903644382953644 5.959356784820557 6.620009899139404
+4.4181623458862305 0.0026063492987304926 0.21418169140815735 6.611691951751709 6.251286506652832
+4.756228923797607 0.004289204254746437 0.06719846278429031 3.24853515625 9.77233600616455
+4.552772521972656 0.0031062839552760124 0.27324768900871277 7.3967108726501465 6.283786296844482
+4.511558532714844 0.002924931701272726 0.2244272083044052 6.815325736999512 6.546410083770752
+4.371138572692871 0.004534544423222542 0.08705595880746841 3.2596731185913086 8.873507499694824
+4.431758880615234 0.003094869665801525 0.181365504860878 6.1389665603637695 6.4398193359375
+4.378685474395752 0.0030796071514487267 0.17745469510555267 6.822868824005127 5.910989761352539
+4.536359786987305 0.005234676878899336 0.05642205476760864 3.050165891647339 9.412261962890625
+4.34744930267334 0.003481661668047309 0.16433194279670715 6.197441101074219 6.0549139976501465
+4.310267925262451 0.007111624348908663 0.24744945764541626 7.117215156555176 5.616497039794922
+4.486175060272217 0.003082074224948883 0.057575467973947525 3.0424082279205322 9.336939811706543
+4.262304306030273 0.008678397163748741 0.1446814388036728 6.152136325836182 6.787021636962891
+4.420350074768066 0.009129521436989307 0.32306361198425293 8.711939811706543 4.66732931137085
+4.497283458709717 0.0029059741646051407 0.06015750393271446 2.7898457050323486 9.52900218963623
+4.037683486938477 0.002838472370058298 0.11146567761898041 4.502479553222656 7.006784439086914
+4.267483711242676 0.0028541935607790947 0.2595292627811432 7.744819641113281 4.930395603179932
+4.243061542510986 0.0055841789580881596 0.07358638942241669 3.5965781211853027 8.252680778503418
+4.215356826782227 0.005168249364942312 0.1030576303601265 4.4014692306518555 7.839999675750732
+4.262025833129883 0.0030566134955734015 0.300423264503479 8.164140701293945 4.930785179138184
+4.092628479003906 0.0043916055001318455 0.09071867167949677 4.2131218910217285 7.311185836791992
+4.088506698608398 0.0039374311454594135 0.08201576769351959 4.008776664733887 7.282649993896484
+4.220724582672119 0.002709218766540289 0.15584687888622284 6.0168561935424805 6.006224155426025
+4.017931938171387 0.00308048608712852 0.09394729137420654 5.009641647338867 6.412130832672119
+4.228157043457031 0.0037906530778855085 0.06141233444213867 3.841980457305908 7.969844341278076
+4.055586814880371 0.01029894594103098 0.12905536592006683 6.344673156738281 5.495707035064697
+4.070987701416016 0.007767384871840477 0.07966601103544235 5.157752990722656 6.367331504821777
+3.9698939323425293 0.003364402102306485 0.05469954386353493 3.9123995304107666 7.305241584777832
+3.8856594562530518 0.003025277052074671 0.08978775888681412 4.973645210266113 6.192290306091309
+4.147519111633301 0.0028251011390239 0.12190266698598862 5.974234104156494 5.523310661315918
+4.001870155334473 0.004998221527785063 0.08009206503629684 3.593902349472046 7.504701614379883
+4.020960807800293 0.0031459920573979616 0.12440972775220871 4.595152854919434 6.656228542327881
+4.100412368774414 0.0034923155326396227 0.21333442628383636 6.974575519561768 4.876172065734863
+4.125182628631592 0.003053730120882392 0.060527410358190536 3.440274715423584 7.80585241317749
+4.065999984741211 0.002979004755616188 0.07328955829143524 3.24782395362854 7.741550922393799
+4.002124786376953 0.005633350927382708 0.20011167228221893 7.086719989776611 4.721829891204834
+3.8693182468414307 0.0030544160399585962 0.11908145248889923 4.437038421630859 6.440038681030273
+3.9558234214782715 0.007004801649600267 0.07446420937776566 3.4687142372131348 7.5227837562561035
+3.864755630493164 0.0045569417998194695 0.1455206722021103 5.346372604370117 5.735536098480225
+3.8635945320129395 0.00321844220161438 0.16876280307769775 5.856716632843018 5.432933807373047
+3.9086718559265137 0.0031566147226840258 0.08272600173950195 3.7678701877593994 7.1300740242004395
+3.8832035064697266 0.003115599974989891 0.08583451807498932 3.978689193725586 6.6009321212768555
+3.854678153991699 0.007413072977215052 0.11716894805431366 5.048062324523926 5.597292423248291
+3.711003303527832 0.005980146117508411 0.09102536737918854 4.081021785736084 5.890818119049072
+3.858119487762451 0.003093192121013999 0.06649590283632278 3.3348793983459473 7.0024285316467285
+3.7374043464660645 0.0031026601791381836 0.10412769019603729 4.604864120483398 5.844191074371338
+3.8142952919006348 0.002943781204521656 0.10995770245790482 4.439831256866455 6.005554676055908
+3.746781826019287 0.002806703094393015 0.10338383913040161 4.430851936340332 6.134519577026367
+3.7370901107788086 0.002799863228574395 0.09783603996038437 4.236443996429443 6.042731761932373
+3.638068914413452 0.0030070904176682234 0.12703171372413635 4.817493438720703 5.4331817626953125
+3.7775650024414062 0.0029798184987157583 0.08125223964452744 4.137610912322998 6.3009185791015625
+3.7040610313415527 0.002892489777877927 0.08562615513801575 4.160228252410889 5.8010454177856445
+3.640101432800293 0.003310216823592782 0.08770270645618439 4.272277355194092 5.612390518188477
+3.7089710235595703 0.0034094159491360188 0.060064759105443954 3.5967109203338623 6.445226669311523
+3.730228900909424 0.005486132577061653 0.0914030596613884 3.8973658084869385 6.197526454925537
+3.61708927154541 0.006049109157174826 0.13327063620090485 4.769344806671143 4.9507155418396
+3.6072258949279785 0.0031582536175847054 0.08953819423913956 3.246190309524536 6.486567974090576
+3.691938877105713 0.004250473342835903 0.06875713914632797 3.2135584354400635 6.621501445770264
+3.8479034900665283 0.004015885293483734 0.1624048948287964 5.911251068115234 4.829733371734619
+3.6330933570861816 0.006610908079892397 0.07231476157903671 2.967723846435547 6.655341625213623
+3.7616281509399414 0.004659843631088734 0.08399839699268341 3.3769843578338623 6.660409927368164
+3.616089344024658 0.003636438399553299 0.17335444688796997 5.205489635467529 4.765808582305908
+3.675203800201416 0.0033451616764068604 0.08408781886100769 3.6898179054260254 6.266198635101318
+3.6059610843658447 0.003719310974702239 0.06969363242387772 3.300612449645996 6.6775407791137695
+3.504384994506836 0.0033766913693398237 0.13877034187316895 5.170243263244629 4.698305606842041
+3.490999221801758 0.0031280475668609142 0.08878140896558762 4.211969375610352 5.49625825881958
+3.713881492614746 0.003342251991853118 0.06485294550657272 2.838362693786621 6.987827777862549
+3.6763916015625 0.004172794055193663 0.15635570883750916 5.201360702514648 4.872502326965332
+3.6063473224639893 0.003924650605767965 0.09494492411613464 3.765831708908081 6.035249710083008
+3.5982933044433594 0.0033773593604564667 0.07919351756572723 2.7751312255859375 6.675537109375
+3.481092929840088 0.004583392292261124 0.13612787425518036 4.573653697967529 4.9074249267578125
+3.445239543914795 0.002915862249210477 0.08893831074237823 3.8049020767211914 5.4971795082092285
+3.591792583465576 0.0038070837035775185 0.06193932145833969 2.7907073497772217 6.945589542388916
+3.56485652923584 0.0030768883880227804 0.14629170298576355 5.292057037353516 4.603923320770264
+3.4883522987365723 0.003040596144273877 0.10344705730676651 3.900078296661377 5.567991733551025
+3.4862260818481445 0.003394094994291663 0.0634879320859909 3.122750759124756 6.461058139801025
+3.4118576049804688 0.0032999212853610516 0.08424855768680573 4.119073390960693 5.404160976409912
+3.5965259075164795 0.004267698619514704 0.1190248653292656 4.814492702484131 4.851497650146484
+3.421065330505371 0.0029377152677625418 0.03866046667098999 2.355386972427368 6.674127578735352
+3.3671226501464844 0.004205759614706039 0.05915672332048416 3.0840985774993896 5.727874755859375
+3.4740657806396484 0.006648975424468517 0.15612797439098358 5.487617492675781 4.127128601074219
+3.4488372802734375 0.003099747933447361 0.05980391055345535 2.5781426429748535 6.664844036102295
+3.3855843544006348 0.0031042522750794888 0.05957839637994766 2.9297776222229004 6.213932037353516
+3.5468578338623047 0.004218119662255049 0.16687722504138947 5.428776741027832 4.180073261260986
+3.344606876373291 0.0036043371073901653 0.07279661297798157 3.109410047531128 5.87086820602417
+3.475771903991699 0.004749497398734093 0.05597677081823349 2.5479395389556885 6.58293342590332
+3.4508414268493652 0.003613067325204611 0.1430790275335312 4.451412677764893 4.594465732574463
+3.2862420082092285 0.0041093179024755955 0.08505091071128845 3.442136526107788 5.269561290740967
+3.3943448066711426 0.006006599869579077 0.054321084171533585 2.8475167751312256 6.281555652618408
+3.3526368141174316 0.0035416516475379467 0.08202483505010605 3.7224643230438232 5.456127166748047
+3.20436954498291 0.00429072231054306 0.09498089551925659 4.060601234436035 4.587981224060059
+3.2761144638061523 0.003767400048673153 0.0658489391207695 2.8880279064178467 5.818134307861328
+3.274711847305298 0.004315816331654787 0.07121054083108902 3.0099809169769287 5.7406392097473145
+3.3436708450317383 0.0051453206688165665 0.12677070498466492 4.447368144989014 4.515134334564209
+3.2649030685424805 0.0030894738156348467 0.07796426862478256 3.359341621398926 5.4585041999816895
+3.3054022789001465 0.0034032270777970552 0.06040085852146149 2.786766767501831 5.970472812652588
+3.2278404235839844 0.0032370039261877537 0.08122998476028442 3.863785743713379 4.576094150543213
+3.256497859954834 0.0034845792688429356 0.05805782601237297 3.114543914794922 5.624128818511963
+3.177884578704834 0.0041701653972268105 0.06665591895580292 2.997831344604492 5.419617176055908
+3.1635842323303223 0.004319665487855673 0.08741921931505203 4.046016693115234 4.724268436431885
+3.171114921569824 0.0033207503147423267 0.06760267913341522 2.950728416442871 5.464331150054932
+3.2067084312438965 0.003355759894475341 0.06961900740861893 2.8078343868255615 5.692041873931885
+3.12355375289917 0.0039661116898059845 0.0894145742058754 3.7076210975646973 4.623778343200684
+3.265043258666992 0.003733430989086628 0.07318845391273499 2.912494421005249 5.347103118896484
+3.262259006500244 0.0035183702129870653 0.06074177473783493 2.3052923679351807 6.061827182769775
+3.2132444381713867 0.003828640328720212 0.09963701665401459 3.382863998413086 4.876228332519531
+3.005119800567627 0.0040326910093426704 0.0863996371626854 3.3202128410339355 4.536506175994873
+3.167473316192627 0.007123690564185381 0.053742509335279465 2.473215103149414 6.079121112823486
+3.108879566192627 0.0038715973496437073 0.09217671304941177 3.321317434310913 4.763245582580566
+3.1234688758850098 0.0037663751281797886 0.09287262707948685 3.6576592922210693 4.4840850830078125
+3.183880567550659 0.003402303671464324 0.0530875138938427 1.9482022523880005 6.105441093444824
+3.241779327392578 0.004418161232024431 0.07467315346002579 2.6891276836395264 5.552256107330322
+3.287527561187744 0.004849465563893318 0.1246655210852623 4.318915843963623 3.9870784282684326
+3.1805925369262695 0.004313038196414709 0.043377310037612915 2.327378034591675 6.145829200744629
+2.999110460281372 0.0036860478576272726 0.052183669060468674 2.6667613983154297 5.3761467933654785
+3.3489863872528076 0.004311475437134504 0.14096853137016296 4.779354572296143 4.170400142669678
+3.111968517303467 0.0033360139932483435 0.05747460201382637 2.482990026473999 5.482112884521484
+3.1140847206115723 0.0034160816576331854 0.04947276413440704 2.3409833908081055 5.835331439971924
+3.1583871841430664 0.0039042183198034763 0.10559233278036118 4.264029502868652 4.146317958831787
+3.159808874130249 0.003221445484086871 0.07091596722602844 3.1396079063415527 4.963814735412598
+3.1297035217285156 0.0032077557407319546 0.043910760432481766 1.8936823606491089 6.1297407150268555
+3.040250301361084 0.006496834568679333 0.08565906435251236 3.241042375564575 4.532522201538086
+3.141169548034668 0.008061669766902924 0.07729090005159378 3.430117130279541 4.729026794433594
+3.1196255683898926 0.004061215557157993 0.039317935705184937 2.1510589122772217 5.77971887588501
+3.0497653484344482 0.003938451409339905 0.07665277272462845 3.097047805786133 4.997101306915283
+3.140139579772949 0.003770282957702875 0.07605922222137451 3.921403408050537 4.398036479949951
+2.8398003578186035 0.004804858937859535 0.05442988872528076 2.3795926570892334 4.843491554260254
+3.112104892730713 0.005039186682552099 0.048925045877695084 2.3331234455108643 5.639102458953857
+3.110804796218872 0.003995477221906185 0.08929577469825745 3.836273193359375 4.353917121887207
+2.924044132232666 0.004126922693103552 0.0628284215927124 2.428565502166748 4.979681491851807
+3.1582932472229004 0.0038003744557499886 0.04724796488881111 1.992488145828247 6.121822834014893
+3.104402780532837 0.004990256391465664 0.11218078434467316 3.3839190006256104 4.423742294311523
+3.123415946960449 0.0037313809152692556 0.08660350739955902 3.632493734359741 4.7041497230529785
+3.1387531757354736 0.007443233858793974 0.047974828630685806 2.0710947513580322 6.136469841003418
+2.998552083969116 0.004506608005613089 0.061663806438446045 3.0277559757232666 5.000101566314697
+3.032261371612549 0.004061452113091946 0.09908313304185867 3.925293207168579 4.298740386962891
+3.291110038757324 0.0035826368257403374 0.04788228124380112 2.271711826324463 6.236457824707031
+2.9930057525634766 0.0038603898137807846 0.07056397944688797 3.0619382858276367 4.5303239822387695
+2.9354066848754883 0.003449269337579608 0.06549648940563202 3.177375555038452 4.328341960906982
+2.9109153747558594 0.0038319791201502085 0.046397339552640915 2.20625901222229 5.08178186416626
+2.9745163917541504 0.0034261371474713087 0.042425476014614105 2.484790802001953 5.148287296295166
+3.0159547328948975 0.009269052185118198 0.07592888176441193 3.616568088531494 3.8783249855041504
+2.952932596206665 0.008276104927062988 0.046228744089603424 2.01478910446167 5.4847917556762695
+2.9446609020233154 0.018315186724066734 0.052869267761707306 2.432708740234375 4.912925720214844
+3.010986328125 0.005756989121437073 0.06577035784721375 3.2642831802368164 4.474074840545654
+2.9398193359375 0.003873254172503948 0.059247955679893494 2.991224765777588 4.578798294067383
+3.0505900382995605 0.0070896209217607975 0.04571245238184929 2.0217723846435547 5.791703224182129
+2.8599119186401367 0.0037037900183349848 0.0736708864569664 3.0377895832061768 4.223706245422363
+2.878131151199341 0.004147302359342575 0.08075904846191406 2.6733458042144775 4.49491024017334
+2.9548587799072266 0.003632900072261691 0.04816479980945587 2.2388734817504883 5.317519187927246
+2.9056472778320312 0.003698532935231924 0.062210436910390854 2.511735439300537 4.881228446960449
+2.759993553161621 0.0037515289150178432 0.0824543759226799 3.0272164344787598 3.9170923233032227
+2.861217737197876 0.004216532688587904 0.048813596367836 2.3038339614868164 4.9899373054504395
+3.02466082572937 0.0035278128925710917 0.05206703394651413 2.3054840564727783 5.101314067840576
+2.982332706451416 0.0050221639685332775 0.08184700459241867 3.046076536178589 4.121588230133057
+2.9545764923095703 0.005044796969741583 0.05399644747376442 2.59721302986145 5.075848579406738
+2.9211246967315674 0.005364621058106422 0.0458696223795414 2.0711143016815186 5.448758125305176
+2.9641120433807373 0.007244155276566744 0.08809457719326019 3.559206247329712 3.818974018096924
+2.753434419631958 0.0037417167332023382 0.05891556292772293 2.3676979541778564 4.788445949554443
+2.931212902069092 0.006867275573313236 0.04853897914290428 2.0460269451141357 5.5698137283325195
+2.9024925231933594 0.004732044413685799 0.08477185666561127 3.044097900390625 4.123342990875244
+2.838250160217285 0.00440927455201745 0.06377160549163818 2.549842596054077 4.625210762023926
+3.023838996887207 0.004173725377768278 0.04689762368798256 2.036574125289917 5.515024662017822
+2.845604419708252 0.0040662228129804134 0.061095111072063446 2.816474199295044 4.301832675933838
+2.898275852203369 0.0044302744790911674 0.06818193197250366 2.874605894088745 4.456329345703125
+2.927192211151123 0.003658401081338525 0.04850434511899948 1.9647951126098633 5.199476718902588
+2.771251678466797 0.0036515831016004086 0.05379704758524895 2.664764404296875 4.389348030090332
+2.8461360931396484 0.0037962014321237803 0.06014730781316757 2.997746706008911 4.169159412384033
+2.8778207302093506 0.00344187137670815 0.03829202800989151 1.6440832614898682 5.393142223358154
+2.920891284942627 0.004242445342242718 0.060475949198007584 2.8075287342071533 4.444828987121582
+2.802421808242798 0.003593332367017865 0.05505753308534622 2.7691290378570557 4.170833587646484
+2.744640350341797 0.004242972005158663 0.040998682379722595 1.8532767295837402 5.031521320343018
+2.8194165229797363 0.0035489683505147696 0.04434378817677498 2.200214385986328 4.857565879821777
+2.847310781478882 0.004374573938548565 0.07955310493707657 3.164440393447876 3.861476421356201
+2.8657889366149902 0.0034720872063189745 0.05025531351566315 2.6607799530029297 4.606383323669434
+2.903852701187134 0.0032488370779901743 0.03732670471072197 1.7389962673187256 5.555969715118408
+2.8165390491485596 0.006980944890528917 0.06915063410997391 3.3100292682647705 3.8075122833251953
+2.7559943199157715 0.005628931801766157 0.05162537097930908 2.5585646629333496 4.536006927490234
+2.790321111679077 0.0036067310720682144 0.04203970357775688 1.8595324754714966 4.916510581970215
+2.6207685470581055 0.0035365517251193523 0.048175081610679626 2.439568519592285 3.8292930126190186
+2.727255344390869 0.0032855914905667305 0.04157482460141182 1.9478943347930908 4.7129597663879395
+2.727947473526001 0.003589158644899726 0.051555559039115906 2.2142062187194824 4.650673866271973
+2.789233922958374 0.004283679183572531 0.06399817019701004 3.2793917655944824 3.833841562271118
+2.7069082260131836 0.0038005849346518517 0.058253344148397446 2.1286866664886475 4.454507827758789
+2.693154811859131 0.0043337466195225716 0.042496830224990845 1.7294397354125977 4.9286346435546875
+2.584723711013794 0.0039991880767047405 0.07281101495027542 2.3888895511627197 3.8980579376220703
+2.706537961959839 0.0042309085838496685 0.07653128355741501 2.3386504650115967 4.088850021362305
+2.8052594661712646 0.0037936836015433073 0.052013907581567764 1.9083212614059448 5.010652542114258
+2.635127067565918 0.003730293596163392 0.06629397720098495 2.4495885372161865 4.111302375793457
+2.6127076148986816 0.003786667948588729 0.05767180398106575 2.4761624336242676 3.8708178997039795
+2.839179515838623 0.004074582830071449 0.03876173868775368 1.605021595954895 5.184404373168945
+2.731159210205078 0.003920785617083311 0.06082357466220856 2.2856743335723877 4.394491672515869
+2.775982141494751 0.0043953703716397285 0.0672219842672348 3.0237462520599365 3.845118761062622
+2.7589948177337646 0.0034503270871937275 0.04604542627930641 2.059605836868286 4.881852149963379
+2.746783494949341 0.0034988021943718195 0.04604955017566681 2.027456760406494 4.708268165588379
+2.667135238647461 0.0038094655610620975 0.05148385465145111 2.632999897003174 3.764784574508667
+2.6544370651245117 0.003852618159726262 0.042649950832128525 1.9571924209594727 4.645623683929443
+2.7414989471435547 0.0038950699381530285 0.03855666518211365 1.8035188913345337 4.758266448974609
+2.730651378631592 0.004798911977559328 0.05157983675599098 2.414614200592041 4.211737632751465
+2.563467264175415 0.003561323508620262 0.048736266791820526 2.23748517036438 4.060364723205566
+2.709516763687134 0.0033916118554770947 0.037867650389671326 1.8248947858810425 4.934010028839111
+2.6334762573242188 0.0034250272437930107 0.050329726189374924 2.1336841583251953 4.348479747772217
+2.6713624000549316 0.004122957121580839 0.06842261552810669 2.8325541019439697 3.725266456604004
+2.6091132164001465 0.0036081895232200623 0.04036779701709747 1.969881296157837 4.639393329620361
+2.570242404937744 0.0033989306539297104 0.05114512890577316 1.687565565109253 4.485219478607178
+2.674955129623413 0.0036942949518561363 0.06915824115276337 2.5515506267547607 3.4399399757385254
+2.699103355407715 0.0033611590042710304 0.04269937425851822 1.7808815240859985 4.663976669311523
+2.755049228668213 0.00340100284665823 0.034598998725414276 1.4618651866912842 5.355345726013184
+2.854917049407959 0.00685743847861886 0.09508749842643738 3.441176414489746 3.337836742401123
+2.622025489807129 0.003442391287535429 0.04500451683998108 2.1640141010284424 4.458320617675781
+2.6431055068969727 0.005117010325193405 0.039289407432079315 1.3727002143859863 5.293365001678467
+2.634615898132324 0.003793602343648672 0.09685014188289642 2.950404405593872 3.2943689823150635
+2.5939974784851074 0.0038844505324959755 0.06027236208319664 2.3119654655456543 4.174847602844238
+2.545408010482788 0.0065870024263858795 0.04114694893360138 1.4162859916687012 4.864618301391602
+2.6157469749450684 0.0033094934187829494 0.049475252628326416 2.208897113800049 4.135314464569092
+2.6674022674560547 0.00404640706256032 0.05418572202324867 2.4789843559265137 3.7404110431671143
+2.6361489295959473 0.004149979911744595 0.03651346638798714 1.6408871412277222 4.878879547119141
+2.569129467010498 0.007562628947198391 0.039670009166002274 2.297872304916382 4.089655876159668
+2.69686222076416 0.007515566423535347 0.04973044991493225 2.309924364089966 3.912517547607422
+2.4919331073760986 0.003325424389913678 0.0329357385635376 1.5791151523590088 4.422548770904541
+2.6106338500976562 0.003250404726713896 0.03997819870710373 1.7784934043884277 4.494634628295898
+2.681220531463623 0.004126886371523142 0.06825223565101624 2.410102605819702 3.8213839530944824
+2.544677257537842 0.0034036333672702312 0.04727752506732941 1.732767939567566 4.19491720199585
+2.655485153198242 0.004058376420289278 0.042752400040626526 1.5809624195098877 4.893096446990967
+2.686535596847534 0.003685815492644906 0.06176057085394859 2.397688150405884 4.019949436187744
+2.561006546020508 0.004435483831912279 0.06022118404507637 2.2303390502929688 3.87768292427063
+2.5930628776550293 0.0076799108646810055 0.04345919191837311 1.826198935508728 4.4461283683776855
+2.6115400791168213 0.005536852870136499 0.044550806283950806 1.6259909868240356 4.549592971801758
+2.5032620429992676 0.00350664253346622 0.07122528553009033 2.545977830886841 3.4380526542663574
+2.6247825622558594 0.0039122579619288445 0.04964355379343033 1.9654040336608887 4.504370212554932
+2.5936450958251953 0.004128345288336277 0.04100419580936432 1.6973294019699097 4.5062689781188965
+2.619861602783203 0.0043419124558568 0.04569840803742409 2.501434326171875 3.6765296459198
+2.5015735626220703 0.0028489872347563505 0.03613751381635666 1.843826174736023 4.009599685668945
+2.60699462890625 0.002815572777763009 0.030142102390527725 1.2440508604049683 5.026559829711914
+2.5411007404327393 0.009448736906051636 0.051775138825178146 2.6020824909210205 3.317840576171875
+2.5242395401000977 0.007079616654664278 0.03804130107164383 1.6112960577011108 4.3071393966674805
+2.519033432006836 0.0032220594584941864 0.037310708314180374 1.374830961227417 4.5875935554504395
+2.4891750812530518 0.0032067708671092987 0.043529536575078964 1.7843191623687744 4.070054531097412
+2.5483226776123047 0.0032287132926285267 0.052378635853528976 2.5833237171173096 3.526844024658203
+2.572981834411621 0.0044029708951711655 0.0407964363694191 1.4564894437789917 4.915266513824463
+2.6363015174865723 0.0037635641638189554 0.05368131399154663 2.2035722732543945 4.263815402984619
+2.6314191818237305 0.004323188681155443 0.07455090433359146 2.692847967147827 3.4692487716674805
+2.555131435394287 0.0051567270420491695 0.04246911406517029 1.5918909311294556 4.75978422164917
+2.4828803539276123 0.005651574116200209 0.042375776916742325 1.3841018676757812 4.444149017333984
+2.418520927429199 0.0035368010867387056 0.05992000550031662 2.1842598915100098 3.40932559967041
+2.4358558654785156 0.0034280018880963326 0.047418538480997086 2.032257080078125 3.7482874393463135
+2.6482863426208496 0.003310038708150387 0.033931270241737366 1.1587599515914917 5.132447242736816
+2.5509519577026367 0.005265361629426479 0.06224296987056732 2.4103829860687256 3.452644109725952
+2.513950824737549 0.004041231237351894 0.051578160375356674 2.162423610687256 3.797978639602661
+2.5320897102355957 0.0035441035870462656 0.047170110046863556 1.398155927658081 4.675901889801025
+2.431278705596924 0.0029340151231735945 0.04683765769004822 1.8074461221694946 3.8398635387420654
+2.456554412841797 0.0030676028691232204 0.04977904260158539 1.9911065101623535 3.7377371788024902
+2.5790610313415527 0.0030134597327560186 0.040071040391922 1.9862688779830933 4.344679355621338
+2.356088638305664 0.002784278243780136 0.051222532987594604 1.666967511177063 4.131004810333252
+2.523869276046753 0.0033429046161472797 0.04402744770050049 2.09661865234375 3.986339807510376
+2.5065672397613525 0.0032069343142211437 0.04278001934289932 1.7758700847625732 4.092862129211426
+2.494385004043579 0.00269434810616076 0.046423789113759995 1.937769889831543 3.7228896617889404
+2.3997697830200195 0.003020466072484851 0.04314090311527252 1.4359066486358643 4.242915630340576
+2.4581446647644043 0.002904457738623023 0.038413506001234055 1.5607322454452515 4.148782253265381
+2.4386088848114014 0.005742170847952366 0.057464394718408585 2.1449320316314697 3.547173500061035
+2.3916077613830566 0.004032738972455263 0.05156542733311653 1.713897466659546 3.9141407012939453
+2.465345621109009 0.0028215665370225906 0.03520321473479271 1.260400414466858 4.5363898277282715
+2.4015841484069824 0.00307678896933794 0.052986279129981995 2.3805019855499268 3.2115957736968994
+2.3339500427246094 0.0029821465723216534 0.03666987642645836 1.5019526481628418 3.8771872520446777
+2.4758193492889404 0.0028578294441103935 0.0347294844686985 1.300605297088623 4.467398166656494
+2.511214256286621 0.0039935666136443615 0.051826201379299164 2.2735378742218018 3.491729497909546
+2.458592414855957 0.003073636209592223 0.03946320340037346 1.6297355890274048 4.258171081542969
+2.485107421875 0.003169798757880926 0.04519651457667351 1.5643532276153564 4.326968193054199
+2.4283571243286133 0.0033829077146947384 0.05773899331688881 1.9254313707351685 3.3634419441223145
+2.461801528930664 0.0059256721287965775 0.04295135661959648 1.664072871208191 3.949794054031372
+2.376267671585083 0.005602389574050903 0.036426059901714325 1.2842202186584473 4.366249084472656
+2.3519821166992188 0.003377574263140559 0.04276569187641144 1.4965215921401978 3.7819015979766846
+2.399146795272827 0.007164651528000832 0.05702567845582962 2.1098172664642334 3.234917163848877
+2.4056899547576904 0.0028305265586823225 0.03223207965493202 1.2281806468963623 4.490728855133057
+2.342156410217285 0.002732998225837946 0.034596748650074005 1.34830641746521 4.170935153961182
+2.4190726280212402 0.004020275082439184 0.05773290619254112 2.6010987758636475 2.8510124683380127
+2.4707069396972656 0.0025608448777347803 0.031100096181035042 1.1373515129089355 4.666764259338379
+2.438410520553589 0.0031674765050411224 0.03917409107089043 1.4551695585250854 4.200824737548828
+2.4965643882751465 0.0048653981648385525 0.06806357949972153 2.7175066471099854 3.1590452194213867
+2.3846635818481445 0.002817735308781266 0.03754686191678047 1.5315589904785156 4.13535737991333
+2.492642879486084 0.0041009788401424885 0.03220030665397644 1.0136330127716064 4.8736348152160645
+2.3857967853546143 0.0036587058566510677 0.06818058341741562 2.240257740020752 3.055220365524292
+2.3148510456085205 0.003106445074081421 0.049283385276794434 1.7918492555618286 3.6744842529296875
+2.438943386077881 0.003039513947442174 0.03304256498813629 1.0750720500946045 4.7443695068359375
+2.438138961791992 0.0026645248290151358 0.04895652085542679 1.988869309425354 3.517951011657715
+2.368319034576416 0.002752550644800067 0.05028781667351723 2.133683204650879 3.115309715270996
+2.359513282775879 0.0039008958265185356 0.03207525983452797 0.9432986378669739 4.730887413024902
+2.3239645957946777 0.004154971335083246 0.03826741501688957 1.540661096572876 3.9640309810638428
+2.4077773094177246 0.010982641950249672 0.07245726883411407 2.898311138153076 2.655761480331421
+2.380845785140991 0.002641191938892007 0.030545024201273918 1.2625963687896729 4.410107135772705
+2.447862148284912 0.0030506004113703966 0.027873992919921875 1.0303677320480347 4.796267509460449
+2.346329689025879 0.005266845691949129 0.05902921035885811 2.064391851425171 3.2973506450653076
+2.3786187171936035 0.0037602921947836876 0.05835731327533722 1.673459529876709 3.355708599090576
+2.361053943634033 0.0062765926122665405 0.036000628024339676 1.2091141939163208 4.3644609451293945
+2.303083896636963 0.007444282993674278 0.034160397946834564 1.1563472747802734 4.267794132232666
+2.3341729640960693 0.002957879798486829 0.0683479830622673 2.1913301944732666 3.1014251708984375
+2.3778984546661377 0.0031091715209186077 0.05128026381134987 1.9986209869384766 3.5939993858337402
+2.4328513145446777 0.002777090994641185 0.03156604617834091 1.1494348049163818 4.654418468475342
+2.2410268783569336 0.004084778483957052 0.04074988514184952 1.6238112449645996 3.5180306434631348
+2.3446240425109863 0.004120929166674614 0.04774622619152069 1.8552920818328857 3.0212674140930176
+2.302886962890625 0.0026804266963154078 0.03176456317305565 1.1421442031860352 4.208995819091797
+2.3898744583129883 0.0024343449622392654 0.03351247310638428 0.9579389095306396 4.4809675216674805
+2.3087563514709473 0.00729542737826705 0.05499657988548279 1.7178606986999512 2.948906898498535
+2.2608914375305176 0.003884886158630252 0.038320012390613556 1.1929481029510498 4.076016426086426
+2.2632155418395996 0.0027362469118088484 0.03307126834988594 1.0845621824264526 4.310342311859131
+2.2599003314971924 0.0028003910556435585 0.053840991109609604 1.9649450778961182 3.3162055015563965
+2.327003002166748 0.004679434932768345 0.058663222938776016 2.3121564388275146 3.1303699016571045
+2.3465723991394043 0.010364525020122528 0.04191291704773903 1.2028400897979736 4.255751609802246
+2.307645320892334 0.003186587244272232 0.03614930436015129 1.211390495300293 4.342501640319824
+2.264760971069336 0.00920629408210516 0.07286576181650162 1.8858767747879028 3.0940661430358887
+2.377887487411499 0.007210341282188892 0.06167817860841751 1.8205755949020386 3.380838632583618
+2.4375431537628174 0.00268200458958745 0.031932272017002106 0.9409888386726379 4.964348793029785
+2.2378485202789307 0.0023151638451963663 0.03576350584626198 1.4606449604034424 3.729510545730591
+2.291609287261963 0.004934191703796387 0.055724773555994034 2.3011608123779297 2.767727851867676
+2.2971343994140625 0.0032098889350891113 0.03353678807616234 1.1902542114257812 4.161663055419922
+2.314072608947754 0.0030371625907719135 0.0344683937728405 0.9854190349578857 4.344897747039795
+2.236647129058838 0.0036512461956590414 0.05718836933374405 1.938767433166504 3.131957769393921
+2.2082695960998535 0.0033646179363131523 0.045979250222444534 1.8425333499908447 3.1001713275909424
+2.423971652984619 0.010587813332676888 0.03491740673780441 0.9047387838363647 4.715253829956055
+2.215503215789795 0.0027077090926468372 0.04314613714814186 1.5802171230316162 3.319383144378662
+2.249948024749756 0.006181684322655201 0.057764336466789246 1.882026195526123 3.3268423080444336
+2.239001989364624 0.005453973077237606 0.03610214218497276 1.325811743736267 3.7524006366729736
+2.3340511322021484 0.0030526933260262012 0.04421893134713173 1.3306461572647095 4.2182722091674805
+2.231844902038574 0.0028650667518377304 0.03811786323785782 1.556913137435913 3.3273189067840576
+2.2774100303649902 0.0025335592217743397 0.03392484784126282 1.4397481679916382 3.657686233520508
+2.167196273803711 0.002632143907248974 0.0318673774600029 1.4058359861373901 3.5925488471984863
+2.093230724334717 0.003104468574747443 0.03373262286186218 1.12394380569458 3.637085437774658
+2.1983914375305176 0.004657271783798933 0.04228837788105011 1.349942684173584 3.606926202774048
+2.2159104347229004 0.003282479476183653 0.04497027024626732 1.5727295875549316 3.4309864044189453
+2.247868061065674 0.0037555501330643892 0.03840131685137749 1.2225770950317383 3.7504639625549316
+2.172943115234375 0.004826077260077 0.03753090277314186 1.175981044769287 3.612846851348877
+2.1234776973724365 0.0029198029078543186 0.03987564146518707 1.3401210308074951 3.3210389614105225
+2.1564438343048096 0.0029864327516406775 0.05358732491731644 1.2556946277618408 3.505455255508423
+2.1347742080688477 0.003247757675126195 0.03872721269726753 1.3706066608428955 3.287766218185425
+2.1488518714904785 0.002574303885921836 0.034985899925231934 1.1903598308563232 3.784533977508545
+2.1555404663085938 0.0030162304174154997 0.03372637927532196 1.2994458675384521 3.676513910293579
+2.1120429039001465 0.004453466739505529 0.04382149130105972 1.4259856939315796 3.044074535369873
+2.058901786804199 0.0026582975406199694 0.03382646664977074 1.3200361728668213 3.403494119644165
+2.1075222492218018 0.0028233909979462624 0.03206013888120651 1.0635387897491455 3.729771375656128
+2.069685935974121 0.002652629278600216 0.034202706068754196 1.0902458429336548 3.5267539024353027
+2.236600399017334 0.0037037241272628307 0.054111965000629425 1.5210316181182861 3.259451150894165
+2.175290107727051 0.002588720293715596 0.043214473873376846 1.5933318138122559 3.4258480072021484
+2.0986104011535645 0.003966072108596563 0.03788541629910469 1.1826196908950806 3.5898079872131348
+2.2678260803222656 0.0028766111936420202 0.03871288150548935 1.3212076425552368 3.647177219390869
+2.2753872871398926 0.0027095982804894447 0.053412601351737976 1.8372398614883423 3.0198514461517334
+2.2366747856140137 0.0024698078632354736 0.033206116408109665 0.8405522108078003 4.176301002502441
+2.1851918697357178 0.0033204101491719484 0.03744015842676163 0.9180333018302917 3.7278265953063965
+2.109461784362793 0.003710351651534438 0.05159679800271988 1.5206294059753418 2.932878255844116
+2.091607093811035 0.002807261422276497 0.03635536506772041 1.29523503780365 3.323042392730713
+2.2669167518615723 0.004525776021182537 0.03372511267662048 1.0610965490341187 4.229312896728516
+2.1116156578063965 0.002687490079551935 0.04248853772878647 1.6052544116973877 3.1072285175323486
+2.1247079372406006 0.0042716143652796745 0.04914692044258118 1.8762387037277222 3.075709581375122
+2.1148948669433594 0.0024735876359045506 0.032694388180971146 0.9883989691734314 3.6932225227355957
+2.2194457054138184 0.0023140537086874247 0.031842902302742004 0.9871858358383179 4.15209436416626
+2.189404249191284 0.004938631318509579 0.046774886548519135 1.6983813047409058 2.9311039447784424
+2.0576043128967285 0.003091466845944524 0.04796670749783516 1.1396613121032715 3.188037633895874
+2.183551788330078 0.002562710316851735 0.030235623940825462 0.9860764145851135 4.036547660827637
+2.074195384979248 0.0026863208040595055 0.03622780367732048 1.2344083786010742 3.222790241241455
+2.0273501873016357 0.003407067386433482 0.04928719252347946 1.5469322204589844 2.962002754211426
+2.1987574100494385 0.002670794492587447 0.03582346439361572 0.9870648980140686 3.7103540897369385
+2.177978277206421 0.002623870735988021 0.03316587954759598 1.1316277980804443 3.7934374809265137
+2.156982183456421 0.0029500341042876244 0.04429018497467041 1.4550679922103882 2.961836576461792
+2.1354012489318848 0.002899357583373785 0.03507395461201668 1.1142547130584717 3.6550257205963135
+2.0967390537261963 0.003030765801668167 0.034577008336782455 1.00502610206604 3.675137519836426
+2.1266884803771973 0.0027816162910312414 0.0484829805791378 1.4444752931594849 2.9733316898345947
+2.0773274898529053 0.0028772063087671995 0.03753618150949478 1.019474744796753 3.4825456142425537
+2.0613293647766113 0.002857058774679899 0.033400557935237885 1.0214990377426147 3.6245226860046387
+2.101311206817627 0.002334228716790676 0.04129748046398163 1.2562506198883057 3.353362798690796
+2.109092950820923 0.002446036087349057 0.036863792687654495 1.198508620262146 3.4128870964050293
+2.0352349281311035 0.0025205365382134914 0.03537081927061081 1.2533187866210938 3.1413919925689697
+2.0175302028656006 0.0025464468635618687 0.029427755624055862 1.0299806594848633 3.327662944793701
+2.0732312202453613 0.00450223358348012 0.032475538551807404 1.1370079517364502 3.378011703491211
+2.047935962677002 0.006476367823779583 0.03501071035861969 1.30342698097229 3.0554959774017334
+2.0472331047058105 0.0031144225504249334 0.03225068747997284 1.0220688581466675 3.292229175567627
+2.037952423095703 0.002550107892602682 0.03468968719244003 1.2345871925354004 3.3129186630249023
+2.0779409408569336 0.00336584378965199 0.04699378460645676 1.2263178825378418 3.1445696353912354
+2.12491512298584 0.0028706141747534275 0.032252080738544464 0.8409675359725952 3.6625449657440186
+2.075326919555664 0.0037982866633683443 0.04104778543114662 1.161537528038025 3.34298038482666
+2.0257315635681152 0.004183824174106121 0.047669362276792526 1.1191487312316895 3.282121419906616
+2.04803466796875 0.0026684754993766546 0.04004077613353729 1.405788779258728 2.9639158248901367
+2.000779390335083 0.0063936058431863785 0.032515477389097214 0.9569733142852783 3.555738687515259
+2.1048083305358887 0.0030613220296800137 0.03658844530582428 1.2335222959518433 3.3009603023529053
+1.9603981971740723 0.003500081831589341 0.045651230961084366 1.2264643907546997 2.974466323852539
+2.0772228240966797 0.0034556188620626926 0.03868253529071808 0.8740466833114624 3.7457807064056396
+2.1007068157196045 0.0025362928863614798 0.045224033296108246 1.210296869277954 3.124760627746582
+1.988915205001831 0.0036162466276437044 0.033602457493543625 1.1365253925323486 3.1104843616485596
+2.0438995361328125 0.003913302440196276 0.032956622540950775 0.9740473031997681 3.5457897186279297
+2.074009895324707 0.003222889266908169 0.036020658910274506 1.1368389129638672 3.347747325897217
+2.042757511138916 0.009342671371996403 0.04551079869270325 1.3302743434906006 3.192140817642212
+2.0940511226654053 0.004660379607230425 0.04204981029033661 1.1409173011779785 3.3237500190734863
+2.0222878456115723 0.0027636075392365456 0.035268452018499374 1.2557463645935059 3.1420233249664307
+2.0713820457458496 0.004940434824675322 0.03195541724562645 0.947910487651825 3.736569404602051
+2.091095447540283 0.0025862406473606825 0.04241332411766052 1.5971612930297852 2.9288418292999268
+2.0013110637664795 0.003807229921221733 0.039320360869169235 1.3874377012252808 2.8850901126861572
+2.0983939170837402 0.0023770476691424847 0.027681011706590652 0.8112154006958008 4.074620246887207
+2.048715114593506 0.0029937901999801397 0.06699385493993759 1.3816354274749756 2.826719045639038
+2.0390195846557617 0.0022390384692698717 0.03866223618388176 1.1827622652053833 2.866259813308716
+2.0182414054870605 0.005021973047405481 0.030349383130669594 0.7254211902618408 3.7190730571746826
+1.9898415803909302 0.00239353789947927 0.02829514630138874 0.7579469084739685 3.5950310230255127
+2.055687427520752 0.00959597248584032 0.06015177071094513 1.6924858093261719 2.558087110519409
+1.9583816528320312 0.00294878869317472 0.038453418761491776 1.260643720626831 3.126235008239746
+2.046086549758911 0.005483720451593399 0.03238185867667198 1.0314668416976929 3.8711719512939453
+2.086740493774414 0.002744189230725169 0.04256097227334976 1.5891189575195312 3.2168655395507812
+1.9665849208831787 0.00362738617695868 0.05271757021546364 1.5110375881195068 2.5766007900238037
+1.8902838230133057 0.0024597516749054193 0.027843696996569633 0.8531237840652466 3.26442289352417
+1.9418704509735107 0.00232839141972363 0.04289733245968819 0.7222417593002319 3.6026670932769775
+1.884800910949707 0.004837582819163799 0.04250917211174965 1.0938478708267212 2.649519443511963
+1.9913784265518188 0.002634241245687008 0.03674421086907387 1.0802438259124756 3.0437216758728027
+2.0421664714813232 0.003956069704145193 0.028504066169261932 0.8323503136634827 3.6803171634674072
+2.01747989654541 0.0029326179064810276 0.034436456859111786 1.2289540767669678 3.294433355331421
+2.0142552852630615 0.002847598632797599 0.049432173371315 1.658443808555603 2.701625108718872
+1.9715979099273682 0.0025149218272417784 0.04073650762438774 1.1019867658615112 3.3180408477783203
+1.9786993265151978 0.0027694765012711287 0.03419185057282448 0.8651126027107239 3.6431758403778076
+1.9455593824386597 0.0023712185211479664 0.043872736394405365 1.1349248886108398 2.8902976512908936
+1.9682867527008057 0.0021889167837798595 0.04341185465455055 1.3090028762817383 2.6651828289031982
+1.8953598737716675 0.00413536699488759 0.029543496668338776 0.6748451590538025 3.401656150817871
+1.9937467575073242 0.0019821154419332743 0.029486896470189095 1.0089606046676636 3.6359665393829346
+1.9949076175689697 0.009814371354877949 0.040217045694589615 1.277014970779419 2.8420732021331787
+2.003324270248413 0.00764129962772131 0.03658125922083855 1.328791856765747 2.7779572010040283
+2.0792109966278076 0.0020989100448787212 0.027491716668009758 0.7560747861862183 4.12162971496582
+2.041865348815918 0.0019298437982797623 0.03238606080412865 1.1370152235031128 3.134122848510742
+1.9778121709823608 0.002405273262411356 0.04360933229327202 1.5868195295333862 2.3765718936920166
+2.0281569957733154 0.002642302308231592 0.030078817158937454 0.9341065883636475 3.786850690841675
+1.9425616264343262 0.0021410491317510605 0.03306722268462181 1.058952808380127 3.4559402465820312
+2.0371761322021484 0.0031686322763562202 0.05191976577043533 1.4021738767623901 2.8882853984832764
+1.9572911262512207 0.002238825662061572 0.0400988906621933 1.291896939277649 2.8749806880950928
+1.9159085750579834 0.005921671167016029 0.030989734455943108 0.6951913237571716 3.4337799549102783
+1.9576499462127686 0.0023293057456612587 0.030078036710619926 0.77984619140625 3.5316455364227295
+1.9310557842254639 0.007157730869948864 0.04426664113998413 1.3611301183700562 2.6592557430267334
+1.8699653148651123 0.006084154359996319 0.03476577252149582 0.9577652215957642 2.8835015296936035
+1.9618545770645142 0.0018447460606694221 0.026722468435764313 0.7379636168479919 3.526209831237793
+1.8907475471496582 0.0018831726629287004 0.028147712349891663 1.1387768983840942 2.810741901397705
+1.9234790802001953 0.0023971423506736755 0.029613204300403595 1.0536158084869385 2.9649345874786377
+1.9443973302841187 0.004599445499479771 0.027670172974467278 0.8076340556144714 3.448000192642212
+1.8616864681243896 0.0067807817831635475 0.04219236224889755 1.0566561222076416 2.7495620250701904
+1.8916336297988892 0.002328297821804881 0.034823935478925705 1.1156176328659058 2.853585958480835
+1.8801867961883545 0.007434742525219917 0.027508212253451347 0.7880774140357971 3.4444329738616943
+1.9571572542190552 0.004219583235681057 0.032496146857738495 0.9097768664360046 3.3028762340545654
+1.9342483282089233 0.0031784852035343647 0.055589236319065094 1.5638427734375 2.4997787475585938
+1.7599384784698486 0.0026342272758483887 0.04081159830093384 0.8037471175193787 3.0582921504974365
+1.9688005447387695 0.004108640365302563 0.02784525603055954 0.8067436218261719 3.6171302795410156
+1.938948631286621 0.0022105544339865446 0.03781194984912872 1.4042927026748657 2.639395236968994
+1.8722219467163086 0.002146013779565692 0.031190501525998116 1.2604002952575684 2.7625999450683594
+1.7660212516784668 0.00227856426499784 0.024572810158133507 0.6584057211875916 3.109064817428589
+1.8786160945892334 0.005020801443606615 0.026389526203274727 0.6811804175376892 3.2945449352264404
+1.9208781719207764 0.009073230437934399 0.03695060685276985 1.0035523176193237 2.604179859161377
+1.90842604637146 0.002709644380956888 0.02752816490828991 0.9361498355865479 2.927164077758789
+1.8071556091308594 0.0030387730803340673 0.02828061953186989 0.698761522769928 3.2616231441497803
+2.019239902496338 0.002229085424914956 0.029636573046445847 0.8905763030052185 3.5728816986083984
+1.931828260421753 0.006299061235040426 0.0612684041261673 1.5666842460632324 2.3192312717437744
+1.9559845924377441 0.002479834482073784 0.0343015231192112 0.866552472114563 3.373225212097168
+2.0246405601501465 0.005263795610517263 0.031666114926338196 0.7769514322280884 3.794995069503784
+1.9038729667663574 0.0028777318075299263 0.05159424617886543 1.292833924293518 2.472811460494995
+1.8494772911071777 0.003144338959828019 0.04372289404273033 1.1878979206085205 2.6609745025634766
+1.888790488243103 0.005249352660030127 0.029787704348564148 0.7515946626663208 3.47794771194458
+1.8580412864685059 0.0020494465716183186 0.031207747757434845 1.0533760786056519 3.246600866317749
+1.8294658660888672 0.004555483814328909 0.04498486965894699 1.299331545829773 2.3330981731414795
+1.9238754510879517 0.00283889751881361 0.026244936510920525 1.037253975868225 3.22832989692688
+1.8125200271606445 0.0025972253642976284 0.02590847574174404 0.7785575985908508 3.1601333618164062
+1.7460360527038574 0.006463534664362669 0.02566099725663662 0.9274894595146179 2.7167587280273438
+1.851065754890442 0.00836748257279396 0.02673536352813244 0.9250126481056213 2.746307373046875
+1.8449615240097046 0.0041245361790061 0.025775212794542313 0.8258721828460693 3.001760244369507
+1.8685448169708252 0.001978485845029354 0.027874113991856575 0.7528937458992004 2.933825731277466
+1.9551661014556885 0.0028069119434803724 0.03142548352479935 0.8190605640411377 3.2364625930786133
+1.8273327350616455 0.002949288347736001 0.04626196622848511 1.1530649662017822 2.657485008239746
+1.858546495437622 0.004415399394929409 0.03731376677751541 0.9209600687026978 3.095802068710327
+1.7974506616592407 0.00475727254524827 0.037363309413194656 0.7563433647155762 3.18721342086792
+1.7675838470458984 0.00333989504724741 0.04898197576403618 1.2160282135009766 2.7353360652923584
+1.8716566562652588 0.003574054455384612 0.04506317153573036 1.4386756420135498 2.5049145221710205
+1.7601908445358276 0.006463455967605114 0.028063183650374413 0.769799530506134 3.0549161434173584
+1.756582260131836 0.0022983383387327194 0.02571430802345276 0.7382007241249084 3.189361572265625
+1.7746270895004272 0.009652533568441868 0.03374391421675682 0.9818967580795288 2.560930013656616
+1.828487753868103 0.010075879283249378 0.03206643834710121 1.0832138061523438 2.4522616863250732
+1.7364312410354614 0.002004846464842558 0.024447226896882057 0.7533653974533081 3.2161130905151367
+1.7968462705612183 0.0019644068088382483 0.0249209962785244 0.7833787202835083 3.0589118003845215
+1.8747575283050537 0.005308866035193205 0.037346456199884415 1.3853086233139038 2.212808609008789
+1.7871067523956299 0.0029377159662544727 0.030481308698654175 1.0134754180908203 2.8115220069885254
+1.9534783363342285 0.0021579160820692778 0.024850735440850258 0.6522034406661987 3.799016237258911
+1.8365525007247925 0.0026351092383265495 0.04416202753782272 1.1214208602905273 2.5318522453308105
+1.8916471004486084 0.002545880153775215 0.044891368597745895 1.2677315473556519 2.518810510635376
+1.8274511098861694 0.00851800013333559 0.028734011575579643 0.6910852193832397 3.2087955474853516
+1.825024127960205 0.004381425678730011 0.03159096837043762 0.7406511306762695 3.35080885887146
+1.8520352840423584 0.004079404287040234 0.04589667543768883 1.0398612022399902 2.6687138080596924
+1.8514398336410522 0.00389539310708642 0.04612501710653305 1.1242481470108032 2.2894792556762695
+1.9315650463104248 0.0032960183452814817 0.02719350904226303 0.586176872253418 3.7368781566619873
+1.6671743392944336 0.0018754999618977308 0.027841534465551376 0.6671903133392334 2.9336636066436768
+1.7906394004821777 0.006974408868700266 0.038969073444604874 1.3806313276290894 2.3122031688690186
+1.815952181816101 0.004126595798879862 0.02900015190243721 0.9440738558769226 2.8047127723693848
+1.808140754699707 0.0019580558873713017 0.04008358344435692 0.8490013480186462 3.1647512912750244
+1.743436574935913 0.0019865829963237047 0.027266625314950943 0.8366134762763977 2.8958117961883545
+1.8802653551101685 0.002893825527280569 0.032858721911907196 1.2120434045791626 2.7266767024993896
+1.7639474868774414 0.002432175213471055 0.03191661089658737 0.9537951946258545 2.6871488094329834
+1.8189555406570435 0.002403518883511424 0.027730008587241173 0.7969512939453125 3.080704689025879
+1.7192885875701904 0.0023234442342072725 0.0292398389428854 0.7339355945587158 2.900405168533325
+1.6476863622665405 0.0025222657714039087 0.037165675312280655 0.895119309425354 2.5184545516967773
+1.7644531726837158 0.0022232162300497293 0.03829800337553024 1.1246687173843384 2.4722726345062256
+1.7683138847351074 0.003578806295990944 0.02774023450911045 0.7887780070304871 2.988482713699341
+1.7623869180679321 0.002514907158911228 0.026938652619719505 0.902067244052887 2.8820643424987793
+1.7793165445327759 0.0028413215186446905 0.03019307367503643 0.6418291330337524 2.863642454147339
+1.7791961431503296 0.00450071832165122 0.03473204746842384 0.8882821202278137 2.631948232650757
+1.703615427017212 0.001994737656787038 0.028719881549477577 0.7330031394958496 2.881582498550415
+1.7195332050323486 0.002006658585742116 0.02834823727607727 0.7858183979988098 2.8577206134796143
+1.8660914897918701 0.0018155163852497935 0.05890649929642677 1.2111340761184692 2.5272700786590576
+1.769626498222351 0.0019840432796627283 0.030070457607507706 0.9213112592697144 2.93245005607605
+1.8130083084106445 0.0027843602001667023 0.027927303686738014 0.827678382396698 2.9210124015808105
+1.7655267715454102 0.0030163023620843887 0.033441200852394104 0.9265094995498657 2.759603261947632
+1.71917724609375 0.0019972913432866335 0.03151067718863487 0.8928638100624084 2.6698856353759766
+1.774315595626831 0.0029596283566206694 0.041407063603401184 0.7373692393302917 2.975398540496826
+1.7110464572906494 0.0019381912425160408 0.028257906436920166 0.8433794975280762 2.7149951457977295
+1.6652781963348389 0.0025810443330556154 0.03168303892016411 0.7547350525856018 2.6733932495117188
+1.7694649696350098 0.003393034217879176 0.03113435208797455 0.8062439560890198 2.634939670562744
+1.7438088655471802 0.0019146197009831667 0.02679058164358139 0.9362980127334595 2.8636157512664795
+1.7697199583053589 0.0019191975006833673 0.02767126075923443 0.848552942276001 2.592188835144043
+1.807201623916626 0.0019319010898470879 0.026454858481884003 0.7985832095146179 2.925593376159668
+1.6515105962753296 0.0023326280061155558 0.030937617644667625 0.8249584436416626 2.428455114364624
+1.6708550453186035 0.0022507954854518175 0.028619175776839256 0.8659945726394653 2.766634941101074
+1.6434690952301025 0.0020312091801315546 0.027686506509780884 0.7795204520225525 2.7932469844818115
+1.7953310012817383 0.0019376102136448026 0.03216862305998802 1.0066049098968506 2.6951022148132324
+1.7021180391311646 0.0025275317020714283 0.03053789772093296 0.915795624256134 2.4286129474639893
+1.8147774934768677 0.0030340144876390696 0.024249324575066566 0.7689998149871826 3.1810553073883057
+1.6752132177352905 0.0024185078218579292 0.029087703675031662 1.0061771869659424 2.46974515914917
+1.6765730381011963 0.003534330753609538 0.02962130308151245 0.7138554453849792 2.5299911499023438
+1.6801425218582153 0.0020997568499296904 0.02586856484413147 0.5968827605247498 3.0951273441314697
+1.7125053405761719 0.0021090505179017782 0.029699288308620453 0.8460505604743958 2.782806158065796
+1.7624272108078003 0.0018843680154532194 0.0399935357272625 1.2561802864074707 2.1936941146850586
+1.7307754755020142 0.004299585707485676 0.02745119109749794 0.8226410150527954 2.89115834236145
+1.7436397075653076 0.0021927193738520145 0.025432968512177467 0.50652676820755 3.2032854557037354
+1.840078353881836 0.005367416422814131 0.04011597856879234 1.2234742641448975 2.2603681087493896
+1.6446616649627686 0.0029341652989387512 0.02836672216653824 0.8884162306785583 2.4445433616638184
+1.73862624168396 0.00231199711561203 0.02424340508878231 0.503403902053833 3.35058856010437
+1.6509428024291992 0.0018432977376505733 0.028327645733952522 0.7402868270874023 2.550119400024414
+1.6220569610595703 0.0034288696479052305 0.03338835388422012 1.064284324645996 2.217589855194092
+1.7020009756088257 0.0018163464264944196 0.025452766567468643 0.7306921482086182 2.843520402908325
+1.6860580444335938 0.0021717133931815624 0.025617776438593864 0.6387526988983154 2.918311595916748
+1.5533158779144287 0.001996680861338973 0.03072119876742363 0.9059006571769714 2.2664692401885986
+1.5998351573944092 0.0017654522089287639 0.028702275827527046 0.9555583596229553 2.2348592281341553
+1.6179159879684448 0.0017811141442507505 0.02249964140355587 0.5187113285064697 2.8572072982788086
+1.6606273651123047 0.0022118366323411465 0.024519354104995728 0.6932038068771362 2.7853844165802
+1.681624412536621 0.004715753253549337 0.03162931278347969 0.8490603566169739 2.2675461769104004
+1.5923364162445068 0.0020578838884830475 0.02457672916352749 0.6494108438491821 2.8395183086395264
+1.7724590301513672 0.0019530373392626643 0.02585209533572197 0.5962873697280884 3.147813558578491
+1.707790732383728 0.002009447431191802 0.03985188156366348 1.1515867710113525 2.0127391815185547
+1.5486094951629639 0.005022284109145403 0.02841857261955738 0.6368111371994019 2.6445744037628174
+1.8010151386260986 0.0038235385436564684 0.02622898668050766 0.4167150557041168 3.5709502696990967
+1.6437568664550781 0.0033846080768853426 0.049741536378860474 1.2916483879089355 1.9251060485839844
+1.5979325771331787 0.0016808087239041924 0.03662082552909851 1.0702393054962158 1.965878963470459
+1.7672085762023926 0.006729196757078171 0.03178185224533081 0.4187065064907074 3.5710089206695557
+1.6546375751495361 0.0020565579179674387 0.024083968251943588 0.5423476099967957 2.8026492595672607
+1.6635441780090332 0.016221096739172935 0.04111099988222122 1.1787365674972534 1.9787616729736328
+1.6204372644424438 0.00429843133315444 0.023811785504221916 0.7878445982933044 2.5839784145355225
+1.6803598403930664 0.0026732520200312138 0.02974826470017433 0.6497639417648315 3.0439321994781494
+1.661489486694336 0.0016439775936305523 0.023591527715325356 0.8086404800415039 2.6207730770111084
+1.6361943483352661 0.008005491457879543 0.03338583931326866 1.0275479555130005 2.038055896759033
+1.6740620136260986 0.004482485353946686 0.025562774389982224 0.6728584170341492 2.8537347316741943
+1.6941198110580444 0.0020826158579438925 0.02365870214998722 0.5392012596130371 3.1379411220550537
+1.6130008697509766 0.0022273766808211803 0.03507784754037857 1.1066324710845947 2.074223756790161
+1.6249182224273682 0.004734295420348644 0.031796831637620926 0.9225327372550964 2.0132899284362793
+1.5875961780548096 0.009031221270561218 0.02321208268404007 0.5071538686752319 3.0601203441619873
+1.5013083219528198 0.002968432381749153 0.029959235340356827 0.6022948622703552 2.7006521224975586
+1.652659296989441 0.010574538260698318 0.061274588108062744 1.1311957836151123 2.033165693283081
+1.6172635555267334 0.0023847827687859535 0.03113020956516266 0.7866807579994202 2.3763372898101807
+1.664696455001831 0.015765506774187088 0.027798114344477654 0.5335696935653687 2.989097833633423
+1.6098711490631104 0.003943817224353552 0.023688193410634995 0.6862239241600037 2.599247455596924
+1.6934967041015625 0.007889047265052795 0.037090837955474854 1.0400617122650146 2.1181063652038574
+1.5681815147399902 0.008757095783948898 0.026724552735686302 0.8361263871192932 2.523447036743164
+1.6442325115203857 0.0022598085924983025 0.02080416865646839 0.43438318371772766 3.1490559577941895
+1.5270907878875732 0.0018144558416679502 0.030596258118748665 0.7618647813796997 2.029353141784668
+1.5337923765182495 0.0020800302736461163 0.02551697939634323 0.8625068664550781 2.144132137298584
+1.6338114738464355 0.0029189891647547483 0.02374143898487091 0.611099362373352 2.886512041091919
+1.6681456565856934 0.0019735221285372972 0.026611968874931335 0.7310472726821899 2.8402061462402344
+1.5848050117492676 0.004448621533811092 0.03875787556171417 0.8863359689712524 2.2854392528533936
+1.5879292488098145 0.001807519467547536 0.031115228310227394 0.8078795671463013 2.2355384826660156
+1.5886001586914062 0.00755371805280447 0.025778116658329964 0.5709618926048279 2.689502239227295
+1.5401535034179688 0.0028799446299672127 0.024795876815915108 0.5997941493988037 2.6699154376983643
+1.5924749374389648 0.006000976078212261 0.0325491726398468 0.8872876763343811 2.1968436241149902
+1.585385799407959 0.00701961200684309 0.029836976900696754 0.8882437348365784 2.263505458831787
+1.5687768459320068 0.0016654065111652017 0.022200433537364006 0.5030658841133118 2.803684711456299
+1.6484248638153076 0.0014793716836720705 0.02358134463429451 0.6250075101852417 2.677720785140991
+1.531619668006897 0.002075096359476447 0.03243834525346756 0.7722308039665222 2.037306547164917
+1.7211099863052368 0.002025413326919079 0.026521971449255943 0.7166793346405029 2.681478261947632
+1.4957356452941895 0.00180937466211617 0.023802384734153748 0.48264798521995544 2.4888150691986084
+1.5507328510284424 0.0019795747939497232 0.02584318444132805 0.5844953656196594 2.634390115737915
+1.5991352796554565 0.0020927300211042166 0.03095933049917221 0.8761265873908997 2.34399151802063
+1.5677931308746338 0.0023453691974282265 0.028346950188279152 0.9186573624610901 2.2265570163726807
+1.4844212532043457 0.004252808168530464 0.024029109627008438 0.7170466184616089 2.3594107627868652
+1.6357700824737549 0.0019182938849553466 0.02080165408551693 0.5381293892860413 2.8625564575195312
+1.5998573303222656 0.008044376969337463 0.030347703024744987 0.5876179337501526 2.433506965637207
+1.4930893182754517 0.007627176586538553 0.03224834054708481 0.6191920042037964 2.1030664443969727
+1.567882776260376 0.0021626888774335384 0.023232363164424896 0.6002436876296997 2.547755718231201
+1.5538220405578613 0.005552292335778475 0.02315150573849678 0.6170064210891724 2.6097331047058105
+1.5064629316329956 0.0021922856103628874 0.02879434823989868 0.9342961311340332 2.263766050338745
+1.534204125404358 0.002840630942955613 0.03775767982006073 0.8884289264678955 2.0996696949005127
+1.5345098972320557 0.002251196885481477 0.026037534698843956 0.5906670689582825 2.4353411197662354
+1.5009868144989014 0.003948407247662544 0.023631500080227852 0.5460475087165833 2.530855417251587
+1.5448225736618042 0.0018688385607674718 0.023470180109143257 0.717455267906189 2.357680559158325
+1.5737934112548828 0.002877993043512106 0.02436232939362526 0.778223991394043 2.2658841609954834
+1.633165955543518 0.003505677916109562 0.021503625437617302 0.5827279090881348 2.8290722370147705
+1.5147998332977295 0.003279800759628415 0.021797331050038338 0.5803824067115784 2.5465545654296875
+1.5589613914489746 0.001880417694337666 0.02640661410987377 0.7866075038909912 2.0355207920074463
+1.5196443796157837 0.002838206710293889 0.02480332925915718 0.6419928669929504 2.3651163578033447
+1.5819350481033325 0.00217369943857193 0.023493684828281403 0.6285433769226074 2.696187973022461
+1.5439300537109375 0.004536671098321676 0.02649262361228466 0.6645966172218323 2.3491005897521973
+1.6166118383407593 0.007128401659429073 0.031042413786053658 0.9639424085617065 2.2478275299072266
+1.5559148788452148 0.0015936321578919888 0.021935831755399704 0.6548852920532227 2.6171698570251465
+1.5399515628814697 0.0030489806085824966 0.02266249805688858 0.5896272659301758 2.5532891750335693
+1.520635962486267 0.0017131709028035402 0.02526264823973179 0.6865273118019104 2.147575855255127
+1.615986704826355 0.0023933621123433113 0.02492741122841835 0.5971196293830872 2.4216606616973877
+1.549262523651123 0.0023258007131516933 0.023324614390730858 0.4202658236026764 2.8445160388946533
+1.4448758363723755 0.0023036987986415625 0.029573431238532066 0.73952317237854 2.2101266384124756
+1.554898738861084 0.002881661057472229 0.027446556836366653 0.9359902739524841 2.12724232673645
+1.5278980731964111 0.007975149899721146 0.024302832782268524 0.6161779761314392 2.6362640857696533
+1.5512547492980957 0.0019576416816562414 0.021858282387256622 0.6336979269981384 2.6735308170318604
+1.5277855396270752 0.007491168566048145 0.03373105078935623 0.788689374923706 2.191002368927002
+1.5518704652786255 0.005305217579007149 0.027881000190973282 0.6957305669784546 2.256666421890259
+1.4592905044555664 0.0022186075802892447 0.021283600479364395 0.6840808987617493 2.18583607673645
+1.4711554050445557 0.003951045218855143 0.023154377937316895 0.5256173014640808 2.645843982696533
+1.4608839750289917 0.00432351091876626 0.02311660163104534 0.6055504679679871 2.2306711673736572
+1.4834394454956055 0.011967668309807777 0.03307211026549339 0.9923378825187683 1.869388222694397
+1.4620051383972168 0.0021932083182036877 0.018700379878282547 0.49639225006103516 2.4697868824005127
+1.545432209968567 0.00344279408454895 0.020294444635510445 0.5663896203041077 2.6541149616241455
+1.5313199758529663 0.00217650830745697 0.026524055749177933 0.7098141312599182 2.2419848442077637
+1.5221424102783203 0.0020802391227334738 0.03518616780638695 0.8918864130973816 1.9300715923309326
+1.5098072290420532 0.0032850077841430902 0.022564643993973732 0.5029506683349609 2.7530746459960938
+1.4613884687423706 0.0023450886365026236 0.024953678250312805 0.3528076112270355 2.7787833213806152
+1.5167418718338013 0.002500183880329132 0.045953359454870224 1.1681060791015625 1.6822788715362549
+1.4497008323669434 0.0032434859313070774 0.025726711377501488 0.7957074046134949 2.0176451206207275
+1.6054693460464478 0.006569023244082928 0.02577388659119606 0.438433974981308 3.1933798789978027
+1.4916012287139893 0.004199076909571886 0.023064762353897095 0.5979934334754944 2.5897388458251953
+1.516204595565796 0.017489595338702202 0.05040903016924858 1.0880463123321533 1.4965810775756836
+1.4144501686096191 0.0016200561076402664 0.019800741225481033 0.5893478989601135 2.2227563858032227
+1.5018740892410278 0.006308173760771751 0.032428860664367676 0.43205752968788147 2.967088460922241
+1.4862715005874634 0.0026180273853242397 0.033681102097034454 0.6774531602859497 2.27602481842041
+1.4927020072937012 0.016382863745093346 0.04296713322401047 1.0165561437606812 1.7008590698242188
+1.372400164604187 0.0034521508496254683 0.02398713491857052 0.5661884546279907 2.264939308166504
+1.5323944091796875 0.006335897371172905 0.02226576767861843 0.42932531237602234 2.8618829250335693
+1.416046142578125 0.003813577815890312 0.02608378231525421 0.6538215279579163 2.067291736602783
+1.4917913675308228 0.0020107994787395 0.03887663036584854 1.0176340341567993 1.63876211643219
+1.4500772953033447 0.002637389814481139 0.02395239844918251 0.44114649295806885 2.514686346054077
+1.469178318977356 0.0021058106794953346 0.020051755011081696 0.36244747042655945 2.757463216781616
+1.3983378410339355 0.004486990161240101 0.032780084758996964 0.7178557515144348 2.1145472526550293
+1.5065361261367798 0.0030534989200532436 0.03766199201345444 1.0321664810180664 1.663148045539856
+1.4253695011138916 0.00594277773052454 0.023110123351216316 0.5073233842849731 2.4078426361083984
+1.5291093587875366 0.005473477300256491 0.023503558710217476 0.4678640067577362 2.8911144733428955
+1.509963035583496 0.005917701404541731 0.029190391302108765 0.6738882660865784 1.8517341613769531
+1.3439645767211914 0.010532020591199398 0.033677272498607635 0.7232192754745483 1.708952784538269
+1.5487637519836426 0.0015091997338458896 0.019017649814486504 0.3791029751300812 2.999760866165161
+1.4276536703109741 0.0019211710896342993 0.021608324721455574 0.48042672872543335 2.31540584564209
+1.5116939544677734 0.001620383933186531 0.025728199630975723 0.7594985365867615 1.8007831573486328
+1.3709053993225098 0.001590896979905665 0.021929241716861725 0.5442606806755066 2.109208822250366
+1.460768699645996 0.0015917522832751274 0.019313964992761612 0.41998419165611267 2.8257200717926025
+1.465391993522644 0.004734575282782316 0.025909727439284325 0.75806725025177 2.1479332447052
+1.438415288925171 0.002680819947272539 0.03259716182947159 0.760665774345398 1.7997052669525146
+1.4404828548431396 0.004386692773550749 0.02224084362387657 0.4259363114833832 2.4182753562927246
+1.507951259613037 0.002853234065696597 0.020005768164992332 0.34798508882522583 2.8301587104797363
+1.4043021202087402 0.006922861095517874 0.039259638637304306 0.6498057246208191 2.0992188453674316
+1.4037957191467285 0.0054142773151397705 0.04369558021426201 0.9749426245689392 1.6135234832763672
+1.4938654899597168 0.007117066066712141 0.024397660046815872 0.5988961458206177 2.534287691116333
+1.4412970542907715 0.011368420906364918 0.02753402665257454 0.45355290174484253 2.711102247238159
+1.360907793045044 0.002597192069515586 0.02934143878519535 0.7299192547798157 1.8414074182510376
+1.342655897140503 0.010151594877243042 0.036935847252607346 0.9073513746261597 1.7266556024551392
+1.3419811725616455 0.0020427308045327663 0.01878262683749199 0.53898024559021 2.221383810043335
+1.4020419120788574 0.002933304524049163 0.02127227559685707 0.406158983707428 2.5379185676574707
+1.3508846759796143 0.001607811776921153 0.01970466412603855 0.6600982546806335 1.8834612369537354
+1.4826375246047974 0.0037905338685959578 0.02350696548819542 0.6562049984931946 1.9214096069335938
+1.4600002765655518 0.0019502539653331041 0.017657384276390076 0.372648149728775 2.6424944400787354
+1.3959673643112183 0.0019270837074145675 0.01934542879462242 0.36896002292633057 2.4621543884277344
+1.4483001232147217 0.002227515447884798 0.02950860932469368 0.8014217019081116 1.8531185388565063
+1.4356954097747803 0.002544906921684742 0.0243168193846941 0.8591676950454712 1.9433022737503052
+1.4121623039245605 0.00627444963902235 0.022361036390066147 0.539585292339325 2.44596791267395
+1.335897445678711 0.0018807266606017947 0.019936345517635345 0.45032936334609985 2.417398452758789
+1.4242178201675415 0.007308757398277521 0.033311858773231506 0.7002217769622803 1.9494727849960327
+1.4077022075653076 0.0037225810810923576 0.026528548449277878 0.7133819460868835 1.7077399492263794
+1.3765480518341064 0.0036730156280100346 0.02287949249148369 0.5299034118652344 2.3099732398986816
+1.3270083665847778 0.0021135699935257435 0.020171426236629486 0.38500770926475525 2.4664981365203857
+1.3756890296936035 0.011088363826274872 0.027963537722826004 0.6775361895561218 1.971057415008545
+1.4688729047775269 0.01444588415324688 0.040021296590566635 0.7971241474151611 1.604148268699646
+1.4133515357971191 0.0022830346133559942 0.01844867877662182 0.37593409419059753 2.7060718536376953
+1.4195642471313477 0.00846288911998272 0.022123074159026146 0.4007378816604614 2.421358108520508
+1.42732834815979 0.0030203009955585003 0.0286372359842062 0.8021621108055115 1.7423312664031982
+1.2955858707427979 0.001965534407645464 0.02911175973713398 0.6252163052558899 1.8162111043930054
+1.425700068473816 0.0021293743047863245 0.019895169883966446 0.4487748444080353 2.5606441497802734
+1.3991458415985107 0.0017082956619560719 0.019512666389346123 0.5100040435791016 2.3643739223480225
+1.2786465883255005 0.0028696958906948566 0.023830384016036987 0.6884375810623169 1.8885852098464966
+1.3931481838226318 0.002059646649286151 0.021770885214209557 0.7941434979438782 1.7227054834365845
+1.2733204364776611 0.0018442481523379683 0.018874434754252434 0.42629194259643555 2.229574680328369
+1.3783881664276123 0.0018581354524940252 0.016186634078621864 0.5030026435852051 2.3419978618621826
+1.311825156211853 0.009243112057447433 0.023750020191073418 0.5758013725280762 1.8685277700424194
+1.377812147140503 0.005885463673621416 0.0392257384955883 0.6486300826072693 1.9786179065704346
+1.3182874917984009 0.002288827672600746 0.01722407527267933 0.5036587715148926 2.198875665664673
+1.4056527614593506 0.0054882788099348545 0.019844835624098778 0.3873608410358429 2.338376998901367
+1.3017688989639282 0.002231214428320527 0.025117618963122368 0.475246787071228 1.8822304010391235
+1.3705942630767822 0.0021821276750415564 0.03286977484822273 0.6583255529403687 1.9163609743118286
+1.3169182538986206 0.0027706269174814224 0.02711670659482479 0.3864171504974365 2.141664981842041
+1.2938010692596436 0.0044628637842834 0.021435819566249847 0.5495985746383667 2.000279188156128
+1.3939651250839233 0.0026573953218758106 0.02285653166472912 0.5256578326225281 2.195194959640503
+1.3845407962799072 0.0018602468771860003 0.021743759512901306 0.5241093635559082 1.9903432130813599
+1.2985491752624512 0.0034153738524764776 0.02022913657128811 0.499900221824646 2.059868097305298
+1.3600353002548218 0.0037418652791529894 0.017394619062542915 0.5655961632728577 2.266740083694458
+1.335513949394226 0.0035634790547192097 0.019619261845946312 0.6359632611274719 1.8371456861495972
+1.3760499954223633 0.001814109506085515 0.019004404544830322 0.5363307595252991 2.178332567214966
+1.3433737754821777 0.0014641741290688515 0.01776985451579094 0.542802095413208 2.129265069961548
+1.3717279434204102 0.0019639802630990744 0.018999967724084854 0.6696678996086121 1.884041428565979
+1.3428782224655151 0.001897971611469984 0.018046142533421516 0.5302223563194275 2.198833465576172
+1.3532605171203613 0.0017753473948687315 0.019283859059214592 0.4694671332836151 2.1233108043670654
+1.3392704725265503 0.0019638966768980026 0.02243809588253498 0.5800414681434631 2.1212799549102783
+1.3872172832489014 0.0027179422322660685 0.021227344870567322 0.49187350273132324 2.244131326675415
+1.3215305805206299 0.0033289764542132616 0.022875873371958733 0.5722663998603821 1.963051438331604
+1.24066960811615 0.0032342677004635334 0.021207068115472794 0.5986080765724182 1.7761797904968262
+1.3583751916885376 0.002231067745015025 0.02106638066470623 0.5024628043174744 2.0224990844726562
+1.3142950534820557 0.0021545353811234236 0.016902592033147812 0.43064525723457336 2.156162977218628
+1.2422151565551758 0.006667181383818388 0.019720520824193954 0.3623657822608948 2.0652639865875244
+1.2927197217941284 0.0060428436845541 0.021283674985170364 0.5734431147575378 1.8480703830718994
+1.326313853263855 0.0016295048408210278 0.017635490745306015 0.5539844632148743 2.0279383659362793
+1.3332173824310303 0.0027785608544945717 0.0181503277271986 0.5735659003257751 2.0447680950164795
+1.3676731586456299 0.001770783681422472 0.01811392232775688 0.49404177069664 2.1170315742492676
+1.3432421684265137 0.00420021265745163 0.02558310143649578 0.5658242702484131 2.244802951812744
+1.259386658668518 0.003313510213047266 0.029421737417578697 0.5856508612632751 1.8896589279174805
+1.285213828086853 0.0036468582693487406 0.022091634571552277 0.5817267298698425 1.9242134094238281
+1.350968837738037 0.011484108865261078 0.020933479070663452 0.4446621239185333 2.080681562423706
+1.3031491041183472 0.003339091781526804 0.018286554142832756 0.4651145040988922 2.0468902587890625
+1.3013694286346436 0.005208133719861507 0.023121926933526993 0.5562808513641357 2.010730743408203
+1.2593939304351807 0.00911681354045868 0.023314503952860832 0.4811881184577942 1.8568195104599
+1.2608445882797241 0.002083249855786562 0.014701131731271744 0.40303486585617065 2.0499773025512695
+1.3298163414001465 0.001894541084766388 0.017237383872270584 0.492674320936203 2.0187270641326904
+1.3339745998382568 0.0015826788730919361 0.017969803884625435 0.4712804853916168 1.872422695159912
+1.2800997495651245 0.0019211848266422749 0.019186966121196747 0.5540816187858582 2.0144293308258057
+1.3112261295318604 0.00233906670473516 0.020719945430755615 0.5166690349578857 2.0798146724700928
+1.329734206199646 0.0015376140363514423 0.02196342498064041 0.5424560308456421 1.9683964252471924
+1.2517635822296143 0.002741579432040453 0.019074220210313797 0.4628356099128723 2.0538289546966553
+1.3121442794799805 0.002270522527396679 0.01982511393725872 0.4314822852611542 2.079982042312622
+1.2617101669311523 0.0015183842042461038 0.023822786286473274 0.5900253653526306 1.7089719772338867
+1.2262296676635742 0.0014702756889164448 0.01922699436545372 0.5090339779853821 2.00921368598938
+1.2044987678527832 0.0014363751979544759 0.017050541937351227 0.41865286231040955 1.9993877410888672
+1.2582436800003052 0.0022226099390536547 0.016812944784760475 0.523521363735199 2.0176889896392822
+1.3126108646392822 0.0020409105345606804 0.01982523687183857 0.6960505247116089 1.7939566373825073
+1.246585726737976 0.0015465429751202464 0.018041938543319702 0.6348589658737183 1.7415000200271606
+1.2059098482131958 0.0018843661528080702 0.017481697723269463 0.4214847981929779 2.0569801330566406
+1.2586725950241089 0.0025164096150547266 0.016095703467726707 0.45144304633140564 2.0057156085968018
+1.3160228729248047 0.007054970134049654 0.024467485025525093 0.6019233465194702 1.7976053953170776
+1.2175517082214355 0.0019464025972411036 0.019955765455961227 0.4949692189693451 1.820764183998108
+1.3106372356414795 0.0049776299856603146 0.01920405589044094 0.4803968071937561 2.133071184158325
+1.2721747159957886 0.003994040656834841 0.01885392516851425 0.4214286804199219 2.029841899871826
+1.2417733669281006 0.0019999719224870205 0.025529786944389343 0.5647101998329163 1.7025078535079956
+1.1919862031936646 0.0028300576377660036 0.024854101240634918 0.5588966012001038 1.6363914012908936
+1.2921249866485596 0.001545894076116383 0.014439569786190987 0.3264389634132385 2.355123996734619
+1.2994134426116943 0.0016803916078060865 0.01835191436111927 0.4559169113636017 2.1278302669525146
+1.274455189704895 0.002365114400163293 0.03040360100567341 0.6966040134429932 1.6022592782974243
+1.1862781047821045 0.002418078249320388 0.019614188000559807 0.5042970776557922 1.7220507860183716
+1.2521576881408691 0.004783857148140669 0.01938166655600071 0.3461814820766449 2.3111233711242676
+1.2027382850646973 0.0018694804748520255 0.02179354429244995 0.5276983976364136 1.9236316680908203
+1.3347034454345703 0.00550390500575304 0.03135155886411667 0.7705752849578857 1.5881620645523071
+1.2436155080795288 0.001433874131180346 0.0178952869027853 0.5780288577079773 1.8747695684432983
+1.3459649085998535 0.004495969042181969 0.020583990961313248 0.3486032485961914 2.4449825286865234
+1.2215625047683716 0.0015668428968638182 0.01762852631509304 0.4837683141231537 1.8232117891311646
+1.2175623178482056 0.006894037127494812 0.02244778536260128 0.6090457439422607 1.6491469144821167
+1.2758986949920654 0.00266548921354115 0.017670007422566414 0.6028624773025513 1.8862780332565308
+1.3022515773773193 0.0019703248981386423 0.017296886071562767 0.37619394063949585 2.373502016067505
+1.2210171222686768 0.001454621204175055 0.015656501054763794 0.5852521657943726 1.9056278467178345
+1.2222809791564941 0.00374540570192039 0.022240718826651573 0.6511134505271912 1.5263302326202393
+1.2554407119750977 0.0019126457627862692 0.01783117651939392 0.4193454384803772 2.015009880065918
+1.2181838750839233 0.001874834531918168 0.01456672977656126 0.444507896900177 2.034656524658203
+1.1645526885986328 0.001896661357022822 0.01732095703482628 0.574618399143219 1.8906587362289429
+1.251265287399292 0.0017859461950138211 0.021371470764279366 0.6260912418365479 1.5785174369812012
+1.2333053350448608 0.0016550719738006592 0.01835063472390175 0.507073163986206 1.7937202453613281
+1.3160767555236816 0.0018307947320863605 0.014617504552006721 0.3835010230541229 2.278336763381958
+1.2294437885284424 0.002538719680160284 0.02092076651751995 0.3836418092250824 1.95942223072052
+1.1806018352508545 0.0034054708667099476 0.030805636197328568 0.48459145426750183 1.5807090997695923
+1.1840522289276123 0.0026172392535954714 0.020194703713059425 0.44352397322654724 1.6586642265319824
+1.2210803031921387 0.008551453240215778 0.017537400126457214 0.3730137050151825 2.1036288738250732
+1.1944994926452637 0.0020795022137463093 0.01699930429458618 0.3984766900539398 2.047349452972412
+1.2726716995239258 0.006144314538687468 0.03237880766391754 0.578228771686554 1.658099889755249
+1.1658213138580322 0.0030478129629045725 0.02289210818707943 0.4436139464378357 1.820284128189087
+1.0936568975448608 0.00286017800681293 0.017286600545048714 0.5043545365333557 1.7970377206802368
+1.3038798570632935 0.003396830055862665 0.018826939165592194 0.5035945773124695 1.9734246730804443
+1.2257561683654785 0.0019896510057151318 0.017664892598986626 0.5603427886962891 1.7126787900924683
+1.2029322385787964 0.00821096170693636 0.02142324298620224 0.5184009075164795 1.779874563217163
+1.2348127365112305 0.005119415931403637 0.019289640709757805 0.494422048330307 2.062389850616455
+1.1884593963623047 0.0013702429132536054 0.019296862185001373 0.752114474773407 1.6972728967666626
+1.1915888786315918 0.003983635455369949 0.019862035289406776 0.5488276481628418 1.5705581903457642
+1.1533169746398926 0.0035318005830049515 0.01854882203042507 0.38140010833740234 1.936329960823059
+1.2292133569717407 0.002071063034236431 0.021942850202322006 0.4656696617603302 2.068932056427002
+1.238728642463684 0.006319214589893818 0.032112251967191696 0.5913751125335693 1.855123519897461
+1.2026355266571045 0.0014504151185974479 0.023559218272566795 0.5349115133285522 1.6467846632003784
+1.2074470520019531 0.008493124507367611 0.020062604919075966 0.38037538528442383 1.8915352821350098
+1.1759639978408813 0.004611613228917122 0.01952214352786541 0.39510229229927063 1.7803407907485962
+1.1668083667755127 0.004019198473542929 0.019851109012961388 0.39237186312675476 1.8101149797439575
+1.1377363204956055 0.012364840134978294 0.02360714226961136 0.5115599036216736 1.7112679481506348
+1.2337543964385986 0.003383696312084794 0.015965303406119347 0.4676302671432495 2.115020990371704
+1.145790934562683 0.002173889195546508 0.016261950135231018 0.5816497206687927 1.548527717590332
+1.2160487174987793 0.005549329333007336 0.01755092479288578 0.6193849444389343 1.5396023988723755
+1.1551368236541748 0.0018461979925632477 0.014367220923304558 0.37266984581947327 1.929283618927002
+1.165732741355896 0.008188236504793167 0.02056356705725193 0.4066096246242523 1.8910969495773315
+1.168285846710205 0.011066203005611897 0.029138730838894844 0.44435447454452515 1.692517876625061
+1.1251916885375977 0.0015923380851745605 0.017927533015608788 0.41022396087646484 1.4701155424118042
+1.211876392364502 0.013942954130470753 0.02013852261006832 0.32107383012771606 1.952459692955017
+1.2041246891021729 0.004458119627088308 0.014996057376265526 0.38962265849113464 2.0206029415130615
+1.1069903373718262 0.0077585517428815365 0.033667225390672684 0.4620846211910248 1.6985254287719727
+1.1865363121032715 0.00802118144929409 0.03995254635810852 0.6619191765785217 1.5095796585083008
+1.202654480934143 0.004555098246783018 0.01750420220196247 0.4787762463092804 1.8761184215545654
+1.1693623065948486 0.013180216774344444 0.02255667932331562 0.3173384666442871 2.1265745162963867
+1.1713149547576904 0.001430536387488246 0.019960841163992882 0.3882737457752228 1.5521571636199951
+1.1159751415252686 0.009297423996031284 0.028800390660762787 0.5683147311210632 1.471577525138855
+1.183121919631958 0.003959158901125193 0.016049083322286606 0.33061158657073975 2.083163261413574
+1.193764090538025 0.0012394518125802279 0.01585996523499489 0.4185401499271393 1.878528118133545
+1.110383152961731 0.00157338532153517 0.01770589128136635 0.5705385208129883 1.5268871784210205
+1.1540799140930176 0.001311195781454444 0.01648051105439663 0.4155767560005188 1.783560872077942
+1.157923936843872 0.0019105600658804178 0.016909027472138405 0.41045182943344116 1.815335750579834
+1.1297321319580078 0.003208283567801118 0.017396697774529457 0.40219587087631226 1.852748990058899
+1.1480560302734375 0.0020531502086669207 0.01836259290575981 0.5137929916381836 1.6781622171401978
+1.1707518100738525 0.00170740089379251 0.01551343034952879 0.5095195174217224 1.690844178199768
+1.129420280456543 0.002954494906589389 0.017132489010691643 0.410351037979126 1.7807121276855469
+1.1820504665374756 0.0013305649627000093 0.016050759702920914 0.45244354009628296 1.9193413257598877
+1.128746747970581 0.003524728585034609 0.02287750504910946 0.45642775297164917 1.7306324243545532
+1.1489092111587524 0.0017310349503532052 0.023143848404288292 0.6581506133079529 1.5646673440933228
+1.0774930715560913 0.004193028900772333 0.015740346163511276 0.40799781680107117 1.7216627597808838
+1.1297285556793213 0.0037977949250489473 0.015711981803178787 0.38615843653678894 1.9691410064697266
+1.1171876192092896 0.0029640397988259792 0.021892044693231583 0.44878387451171875 1.571153998374939
+1.1122958660125732 0.005252352450042963 0.0257948599755764 0.6415137648582458 1.359819769859314
+1.22354257106781 0.0013949234271422029 0.013491405174136162 0.3733060657978058 2.166625738143921
+1.1239148378372192 0.002152259461581707 0.0153264245018363 0.4233725368976593 1.7906757593154907
+1.1356388330459595 0.0014126108726486564 0.01909080147743225 0.49702703952789307 1.514146327972412
+1.0691813230514526 0.0014498995151370764 0.01826469972729683 0.4603979289531708 1.4584028720855713
+1.1693902015686035 0.0015359806129708886 0.014080793596804142 0.3461049795150757 2.1116549968719482
+1.1818945407867432 0.001642131945118308 0.01560641173273325 0.3529036045074463 1.9918525218963623
+1.195223093032837 0.0027583118062466383 0.02127191610634327 0.7814003825187683 1.3754338026046753
+1.1496849060058594 0.001631609397009015 0.016528891399502754 0.530126690864563 1.6599924564361572
+1.1456409692764282 0.003429112024605274 0.01594461314380169 0.3149464428424835 2.1510932445526123
+1.2009193897247314 0.002419513650238514 0.014669091440737247 0.5015804171562195 1.9075783491134644
+1.153981328010559 0.009341198951005936 0.028327258303761482 0.5133489370346069 1.2116830348968506
+1.0729131698608398 0.0015873840311542153 0.013944492675364017 0.26315611600875854 1.7908828258514404
+1.1339459419250488 0.0034474374260753393 0.014376536011695862 0.305418998003006 2.103595018386841
+1.1106929779052734 0.0015350538305938244 0.020250188186764717 0.5555230379104614 1.4758601188659668
+1.0841178894042969 0.0014937638770788908 0.02169850468635559 0.589287281036377 1.2671360969543457
+1.1751511096954346 0.0018974975682795048 0.012909695506095886 0.2693019211292267 2.2846264839172363
+1.0811119079589844 0.001571368076838553 0.018176797777414322 0.34202656149864197 1.7747102975845337
+1.0533722639083862 0.002186815021559596 0.025650687515735626 0.560965895652771 1.1287015676498413
+1.1901156902313232 0.002331182826310396 0.015958404168486595 0.4826425611972809 1.7720086574554443
+1.0765595436096191 0.0027686776593327522 0.01603267341852188 0.2939300537109375 2.027338981628418
+1.157956838607788 0.003619667375460267 0.027089089155197144 0.38117823004722595 1.9188462495803833
+1.2114875316619873 0.007532886229455471 0.04295798018574715 0.6074213981628418 1.3207287788391113
+1.0997756719589233 0.002038701204583049 0.01939285546541214 0.5284103155136108 1.6904430389404297
+1.1380558013916016 0.010937624610960484 0.02375551499426365 0.34724828600883484 1.913482666015625
+1.0917484760284424 0.0019474180880934 0.015822358429431915 0.4003269374370575 1.6888692378997803
+1.111662745475769 0.00989029835909605 0.019786862656474113 0.45736101269721985 1.701277256011963
+1.0967061519622803 0.01454948354512453 0.024551132693886757 0.54127436876297 1.5558075904846191
+1.1212940216064453 0.0014588924823328853 0.014861701056361198 0.3137955963611603 1.8254673480987549
+1.1171743869781494 0.005027592647820711 0.016919488087296486 0.27669739723205566 1.7027837038040161
+1.0088475942611694 0.0028354499954730272 0.01621130481362343 0.4513593912124634 1.404582142829895
+1.0498324632644653 0.0017409332795068622 0.018349522724747658 0.3909494876861572 1.4894311428070068
+1.0528781414031982 0.002380440244451165 0.018442731350660324 0.4281589686870575 1.5017642974853516
+1.0849554538726807 0.001476967940106988 0.014998411759734154 0.36619773507118225 1.9234082698822021
+1.1087173223495483 0.001522251171991229 0.01674075238406658 0.5410787463188171 1.5902827978134155
+1.1295716762542725 0.0016785580664873123 0.016999393701553345 0.5232748985290527 1.381750464439392
+1.0287083387374878 0.0017437118804082274 0.015663743019104004 0.32849234342575073 1.5606920719146729
+1.127988338470459 0.0015012869844213128 0.015028993599116802 0.28445854783058167 2.0865719318389893
+1.1231026649475098 0.004538225941359997 0.02239559404551983 0.4057091772556305 1.7701820135116577
+1.0748192071914673 0.002376156160607934 0.023002950474619865 0.36016175150871277 1.396013855934143
+1.086715579032898 0.005053606815636158 0.01585347391664982 0.4595239460468292 1.6065765619277954
+1.1198186874389648 0.006988121196627617 0.01820918172597885 0.30464255809783936 2.0908396244049072
+0.9965840578079224 0.0026022715028375387 0.01956436224281788 0.40542399883270264 1.4955881834030151
+1.0579819679260254 0.00881793163716793 0.029842594638466835 0.6569209694862366 1.2267903089523315
+1.1107642650604248 0.0012825399171561003 0.014669548720121384 0.3191893696784973 1.891108751296997
+1.1471790075302124 0.0026258514262735844 0.017568064853549004 0.32754960656166077 2.032864809036255
+1.0213453769683838 0.002007835078984499 0.015282646752893925 0.4887363016605377 1.4269509315490723
+1.0931265354156494 0.00562114454805851 0.020990455523133278 0.4850589632987976 1.562624454498291
+1.0261989831924438 0.00198081461712718 0.01695270650088787 0.4861728549003601 1.4576663970947266
+1.0767395496368408 0.0032009698916226625 0.014783468097448349 0.27228471636772156 1.8250000476837158
+1.0604139566421509 0.0028240869287401438 0.015883738175034523 0.32586580514907837 1.7842727899551392
+1.0691931247711182 0.0016391959507018328 0.023694291710853577 0.535559356212616 1.3337403535842896
+1.0836269855499268 0.0014562864089384675 0.019199054688215256 0.41701048612594604 1.5043472051620483
+1.0295770168304443 0.002487274119630456 0.014824298210442066 0.2939029633998871 1.8663452863693237
+1.0799061059951782 0.001531071844510734 0.015226978808641434 0.39765286445617676 1.8915272951126099
+1.0057928562164307 0.0030097952112555504 0.026495279744267464 0.6735005974769592 1.0708324909210205
+1.0849107503890991 0.0015972843393683434 0.016238747164607048 0.4808959662914276 1.4195892810821533
+1.0348851680755615 0.004224908538162708 0.019056808203458786 0.31785741448402405 1.9191728830337524
+1.071874976158142 0.001967409159988165 0.013517364859580994 0.3197673559188843 1.8498069047927856
+1.0431199073791504 0.012847564183175564 0.04218384996056557 0.6190078258514404 1.189369559288025
+1.03846275806427 0.002768691396340728 0.015196716412901878 0.36810553073883057 1.4337677955627441
+1.0770134925842285 0.0049057938158512115 0.01872178167104721 0.31130117177963257 1.9055359363555908
+1.1354460716247559 0.0019385620253160596 0.014611366204917431 0.329734206199646 2.0778541564941406
+1.030577301979065 0.006505495868623257 0.02860126458108425 0.5439228415489197 1.2706328630447388
+1.1058695316314697 0.0028841972816735506 0.024709897115826607 0.6420814394950867 1.3688369989395142
+1.1055634021759033 0.006881528068333864 0.018828900530934334 0.25859567523002625 2.0349280834198
+1.081177830696106 0.004708428867161274 0.017424022778868675 0.3688071668148041 2.0009913444519043
+1.0486934185028076 0.005248645320534706 0.022566892206668854 0.5060778856277466 1.2541719675064087
+1.086539626121521 0.007425529882311821 0.02156110294163227 0.5919407606124878 1.2557473182678223
+1.1029890775680542 0.0013693058863282204 0.012998242862522602 0.23005042970180511 2.1397705078125
+0.9833292961120605 0.0013687681639567018 0.012135456316173077 0.3351563811302185 1.6534976959228516
+1.0808985233306885 0.0032437650952488184 0.020759977400302887 0.4565880000591278 1.3233340978622437
+1.0516523122787476 0.0020876408088952303 0.019207332283258438 0.44577038288116455 1.3711018562316895
+1.082880973815918 0.0031222598627209663 0.015158144757151604 0.3236899971961975 1.8599720001220703
+1.0393961668014526 0.002827298128977418 0.013020099140703678 0.24113832414150238 1.86500883102417
+0.9929193258285522 0.0019803352188318968 0.023904932662844658 0.47133293747901917 1.3182034492492676
+1.0287342071533203 0.0016602850519120693 0.021426506340503693 0.4468527138233185 1.385552167892456
+1.0622916221618652 0.004215427208691835 0.019054515287280083 0.3473481833934784 1.6970785856246948
+0.9407349228858948 0.0027511767111718655 0.013855752535164356 0.34656408429145813 1.620375394821167
+1.0691593885421753 0.0037739716935902834 0.01776796020567417 0.43570637702941895 1.594194769859314
+1.0636861324310303 0.007364504970610142 0.023410974070429802 0.6185684204101562 1.2667546272277832
+0.958365797996521 0.001335536944679916 0.013237705454230309 0.44223761558532715 1.4889702796936035
+1.058178186416626 0.003524224506691098 0.015166503377258778 0.2556624114513397 1.8856453895568848
+1.0401852130889893 0.0013966106344014406 0.013069300912320614 0.347356915473938 1.5873273611068726
+1.0922242403030396 0.005820612423121929 0.025680355727672577 0.5269951224327087 1.4438629150390625
+1.0465137958526611 0.002062099287286401 0.019125625491142273 0.3741437494754791 1.4686256647109985
+0.9798463582992554 0.004404818639159203 0.013863200321793556 0.25447574257850647 1.7048190832138062
+1.0436582565307617 0.0034381169825792313 0.013460781425237656 0.3169369399547577 1.8233226537704468
+1.0346992015838623 0.0022303827572613955 0.02176368609070778 0.5813802480697632 1.389739751815796
+0.9324182271957397 0.0027918145060539246 0.023274557664990425 0.5244656801223755 1.1549901962280273
+1.0524827241897583 0.0019258921965956688 0.013761661015450954 0.3567427396774292 1.7412368059158325
+1.0133185386657715 0.0021092321258038282 0.013825979083776474 0.2770753800868988 1.6761137247085571
+0.998403787612915 0.002889238065108657 0.014448112808167934 0.36652058362960815 1.4910119771957397
+1.068899154663086 0.0064499396830797195 0.02243446372449398 0.5341314673423767 1.2687382698059082
+0.9908277988433838 0.0013767880154773593 0.013597535900771618 0.2928416430950165 1.5675346851348877
+1.0073790550231934 0.0039678215980529785 0.014144480228424072 0.3203030526638031 1.6648266315460205
+1.0539255142211914 0.0019422026816755533 0.013751410879194736 0.4218195974826813 1.6678457260131836
+1.05500066280365 0.002417161362245679 0.017495092004537582 0.4729026257991791 1.470906138420105
+0.9880663752555847 0.002651684684678912 0.019288836047053337 0.4121253490447998 1.464309573173523
+1.0005669593811035 0.0014759371988475323 0.013780971989035606 0.35585498809814453 1.5264899730682373
+0.9783961772918701 0.0029010314028710127 0.013515218161046505 0.3683784306049347 1.6424773931503296
+0.9525818228721619 0.0013519618660211563 0.014153780415654182 0.4312732517719269 1.4888331890106201
+0.9590875506401062 0.003229040652513504 0.025942515581846237 0.48097124695777893 1.1548398733139038
+1.0040339231491089 0.0015391629422083497 0.013264351524412632 0.30779507756233215 1.6740081310272217
+0.9627550840377808 0.0014123953878879547 0.013057033531367779 0.3437970280647278 1.5021015405654907
+1.0271210670471191 0.0014136555837467313 0.01384130958467722 0.5032045245170593 1.375427007675171
+1.0380123853683472 0.0013023070059716702 0.012829100713133812 0.400203675031662 1.7017766237258911
+0.9567960500717163 0.002309231087565422 0.01644543744623661 0.412616103887558 1.3969658613204956
+0.9845996499061584 0.0016703043365851045 0.016216017305850983 0.3644047975540161 1.5829176902770996
+0.979526162147522 0.0016182473627850413 0.014947070740163326 0.44514569640159607 1.3955178260803223
+1.0704466104507446 0.0029653471428900957 0.013357937335968018 0.4231741726398468 1.672201156616211
+0.9513802528381348 0.0014489477034658194 0.013853052631020546 0.4211709499359131 1.3793808221817017
+0.9778217077255249 0.0018734891200438142 0.014441701583564281 0.3339956998825073 1.5854238271713257
+0.963698148727417 0.0024876808747649193 0.01516573503613472 0.44289857149124146 1.4628040790557861
+0.9501211643218994 0.0014431816525757313 0.012738530524075031 0.3449726998806 1.484366536140442
+0.9333972930908203 0.0014620432630181313 0.012915653176605701 0.31856822967529297 1.364823579788208
+0.9736983776092529 0.0013871041592210531 0.012658674269914627 0.31015563011169434 1.4015403985977173
+0.999091625213623 0.001425418653525412 0.012719390913844109 0.27253448963165283 1.6119236946105957
+0.9406182765960693 0.0033067623153328896 0.017547348514199257 0.3832342326641083 1.6021101474761963
+0.9734635949134827 0.002112827030941844 0.020147789269685745 0.4409318268299103 1.3763588666915894
+0.9581606984138489 0.0029778110329061747 0.01720551773905754 0.4027922451496124 1.3917932510375977
+0.9835249185562134 0.008128028362989426 0.01604468934237957 0.35443630814552307 1.5101227760314941
+0.8738287687301636 0.0018179221078753471 0.013835448771715164 0.3060528039932251 1.320487141609192
+0.9567054510116577 0.004586610943078995 0.014754511415958405 0.3456564247608185 1.453568696975708
+0.9322906732559204 0.00788749661296606 0.02295617386698723 0.27644801139831543 1.5090723037719727
+0.8632313013076782 0.0027483408339321613 0.012716105207800865 0.3439236283302307 1.1818965673446655
+0.945676326751709 0.00148597895167768 0.013203385286033154 0.39133763313293457 1.250290036201477
+1.0229682922363281 0.0017161977011710405 0.012026754207909107 0.37861451506614685 1.7275748252868652
+0.9418973922729492 0.002995074959471822 0.018532607704401016 0.3061807155609131 1.4374678134918213
+0.9419942498207092 0.00319282291457057 0.027922676876187325 0.4854450523853302 1.134648084640503
+0.8526065349578857 0.004254685249179602 0.01564531773328781 0.35814225673675537 1.3544343709945679
+0.9501774907112122 0.009382830932736397 0.015586121007800102 0.21946686506271362 1.627989649772644
+0.9404877424240112 0.0015037854900583625 0.015181832946836948 0.3767734169960022 1.4021687507629395
+0.998842716217041 0.005569326691329479 0.02346377819776535 0.41363099217414856 1.4506126642227173
+0.8995825052261353 0.0023331926204264164 0.015497306361794472 0.33277812600135803 1.2244230508804321
+0.9432270526885986 0.003154986770823598 0.01240011677145958 0.3042391836643219 1.5076301097869873
+0.9966434836387634 0.002144893631339073 0.01250485610216856 0.2633458971977234 1.7466790676116943
+1.0014771223068237 0.00501455832272768 0.01636834256350994 0.4195519983768463 1.3304847478866577
+0.9346894025802612 0.0058166407980024815 0.025553660467267036 0.47284138202667236 1.0195095539093018
+0.9182949066162109 0.001864771475084126 0.012181495316326618 0.2935747504234314 1.5769649744033813
+0.947900652885437 0.0035064532421529293 0.014540241099894047 0.27542003989219666 1.8140878677368164
+0.8921757340431213 0.002042587846517563 0.015956507995724678 0.550357460975647 1.06724214553833
+1.0532089471817017 0.002880214713513851 0.01726577617228031 0.5202788710594177 1.2100481986999512
+0.9345394372940063 0.0016502513317391276 0.01126081682741642 0.32731613516807556 1.5292688608169556
+0.9853217601776123 0.0017492948099970818 0.010708821937441826 0.30427995324134827 1.854966640472412
+0.9245256781578064 0.004568452946841717 0.01632457599043846 0.37873953580856323 1.3533800840377808
+0.9472746253013611 0.006161733530461788 0.02514042891561985 0.4121086299419403 1.1987583637237549
+0.969195544719696 0.0022830998059362173 0.01355645339936018 0.41171973943710327 1.5525532960891724
+0.8996765613555908 0.01049075834453106 0.019681641831994057 0.34426215291023254 1.4616363048553467
+0.9284722805023193 0.0030428075697273016 0.013106687925755978 0.3633573353290558 1.3920241594314575
+0.8304770588874817 0.0036836012732237577 0.0216402355581522 0.3906245231628418 1.2729343175888062
+0.9091548323631287 0.004336413461714983 0.020034050568938255 0.33495160937309265 1.3722455501556396
+0.9659157395362854 0.0015949491644278169 0.012708460912108421 0.40307900309562683 1.2640252113342285
+0.9238506555557251 0.004441536962985992 0.014119697734713554 0.3030731976032257 1.5562608242034912
+0.938460648059845 0.001364841707982123 0.01211674977093935 0.3078542649745941 1.5037837028503418
+0.9407138228416443 0.0070776138454675674 0.02054545097053051 0.4561208486557007 1.220456838607788
+0.9005295038223267 0.0035224196035414934 0.016781240701675415 0.31238916516304016 1.3092139959335327
+0.9362878799438477 0.002232789061963558 0.014015231281518936 0.25813809037208557 1.4842500686645508
+0.8159126043319702 0.0034214428160339594 0.013957960531115532 0.26825132966041565 1.1564325094223022
+0.9492248296737671 0.001329449936747551 0.013691454194486141 0.3553844690322876 1.3509020805358887
+0.9504994750022888 0.002911837538704276 0.01451141107827425 0.2685340940952301 1.5618754625320435
+0.9048738479614258 0.0031564536038786173 0.016724834218621254 0.36593809723854065 1.4115172624588013
+0.9617490768432617 0.0012868941994383931 0.01377292349934578 0.28221002221107483 1.318359613418579
+0.8790708184242249 0.0029395567253232002 0.014422711916267872 0.3503171503543854 1.344451665878296
+0.9360532164573669 0.0014038848457857966 0.013094635680317879 0.32263466715812683 1.5635868310928345
+0.8562812805175781 0.0036391483154147863 0.0176413394510746 0.44256648421287537 1.135407567024231
+0.8905801773071289 0.0024446528404951096 0.01679275929927826 0.42727693915367126 1.3110978603363037
+0.9433914422988892 0.0017742272466421127 0.012801778502762318 0.30495485663414 1.3341341018676758
+0.8134874105453491 0.003107962664216757 0.013062643818557262 0.3482359051704407 1.2228299379348755
+0.9310088753700256 0.0012248123530298471 0.010920357890427113 0.25074833631515503 1.5985815525054932
+0.8817087411880493 0.007441739086061716 0.019978612661361694 0.3563174307346344 1.3738597631454468
+0.8908129930496216 0.003897105809301138 0.021602271124720573 0.45435231924057007 1.0880215167999268
+0.9271447658538818 0.004715914838016033 0.012809326872229576 0.2579673230648041 1.5206515789031982
+0.9228119850158691 0.008461142890155315 0.015751535072922707 0.3718479871749878 1.4432170391082764
+0.9384498000144958 0.0012837614631280303 0.014364735223352909 0.3711279034614563 1.371433138847351
+0.888739287853241 0.0049672373570501804 0.019862515851855278 0.3499959409236908 1.3434340953826904
+0.8994122743606567 0.0028968423139303923 0.014266008511185646 0.3705671727657318 1.392702341079712
+0.8775621652603149 0.0015040031867101789 0.012059656903147697 0.4559771418571472 1.296574592590332
+0.9044053554534912 0.0019116790499538183 0.011732963845133781 0.38455942273139954 1.3534998893737793
+0.9113696813583374 0.0017084236023947597 0.010070182383060455 0.2576320767402649 1.377977728843689
+0.8987091779708862 0.005779910366982222 0.016393648460507393 0.3522628843784332 1.3429151773452759
+0.9022537469863892 0.0031494535505771637 0.016795946285128593 0.31368279457092285 1.4696502685546875
+0.9439229369163513 0.002342129359021783 0.014575169421732426 0.31280994415283203 1.559389591217041
+0.8871569037437439 0.00751432916149497 0.0164628978818655 0.43438729643821716 1.2319676876068115
+0.8921835422515869 0.005928869359195232 0.015131162479519844 0.32444512844085693 1.4130712747573853
+0.8322241306304932 0.0014440088998526335 0.016313476487994194 0.33096063137054443 1.2440654039382935
+0.8278947472572327 0.0029121108818799257 0.013896947726607323 0.35679659247398376 1.3634185791015625
+0.9202176332473755 0.0031614142935723066 0.011641060933470726 0.286685585975647 1.5036773681640625
+0.8763408660888672 0.0020144430454820395 0.010913167148828506 0.35853102803230286 1.2584643363952637
+0.7996342778205872 0.0012771678157150745 0.012269112281501293 0.3367709815502167 1.0889712572097778
+0.9054915308952332 0.0012476862175390124 0.010776292532682419 0.2667016386985779 1.491503357887268
+0.8933284282684326 0.0017600037390366197 0.013212810270488262 0.3408583104610443 1.3495408296585083
+0.7965114712715149 0.0021975941490381956 0.01573990471661091 0.4311363995075226 1.0904256105422974
+0.7979594469070435 0.0014747519744560122 0.013325294479727745 0.33536410331726074 1.209132194519043
+0.9306799173355103 0.002970682689920068 0.011884878389537334 0.2563645839691162 1.549918293952942
+0.8281104564666748 0.0014601474395021796 0.013913990929722786 0.28253409266471863 1.226894736289978
+0.868451714515686 0.002164307050406933 0.016361374408006668 0.2956831157207489 1.2452223300933838
+0.8530275225639343 0.0016928182449191809 0.014063980430364609 0.32359084486961365 1.3852657079696655
+0.8724654912948608 0.0016459289472550154 0.012755502946674824 0.3237471282482147 1.3282628059387207
+0.8440070748329163 0.00250982865691185 0.012491447851061821 0.40341663360595703 1.135308027267456
+0.8388121724128723 0.0014670330565422773 0.012066492810845375 0.3482251763343811 1.2221242189407349
+0.9133214950561523 0.0018797332886606455 0.012291964143514633 0.39696040749549866 1.4531068801879883
+0.8624118566513062 0.004401630721986294 0.014721433632075787 0.40338391065597534 1.3115348815917969
+0.8223276734352112 0.0021789362654089928 0.0117774223908782 0.37210705876350403 1.222045660018921
+0.8294132947921753 0.0013809408992528915 0.011303659528493881 0.29257240891456604 1.1953035593032837
+0.8703910708427429 0.002242538146674633 0.0113747613504529 0.4505036473274231 1.1469465494155884
+0.8825359344482422 0.0011357280891388655 0.010067432187497616 0.29474085569381714 1.5149966478347778
+0.8540067076683044 0.004713192582130432 0.016568688675761223 0.293872594833374 1.4726073741912842
+0.8613783717155457 0.003581950906664133 0.022124210372567177 0.4540007412433624 1.1377052068710327
+0.8976895213127136 0.004017508588731289 0.01252717524766922 0.3812847137451172 1.228798508644104
+0.8507541418075562 0.011220070533454418 0.01546835619956255 0.3062492907047272 1.359532117843628
+0.9170272350311279 0.0014167486224323511 0.010930491611361504 0.33716341853141785 1.5396469831466675
+0.86927330493927 0.010163248516619205 0.02172096073627472 0.40367263555526733 1.1069799661636353
+0.8402775526046753 0.006888070143759251 0.015813961625099182 0.33667442202568054 1.1906626224517822
+0.8710911273956299 0.0016219770768657327 0.010116071440279484 0.23380205035209656 1.4842644929885864
+0.7902107238769531 0.0033405404537916183 0.010822386480867863 0.295369952917099 1.2121323347091675
+0.8633126020431519 0.0014105550944805145 0.013719739392399788 0.3604172170162201 1.3338466882705688
+0.7997465133666992 0.0022098487243056297 0.02314262092113495 0.44465863704681396 0.9797321557998657
+0.8256331086158752 0.001896006055176258 0.01578989066183567 0.34923312067985535 1.294270396232605
+0.8502955436706543 0.004485825542360544 0.014037835411727428 0.261470228433609 1.4015151262283325
+0.8221156597137451 0.0013404260389506817 0.017210621386766434 0.31535226106643677 1.2492367029190063
+0.8717689514160156 0.0033900782000273466 0.014402057975530624 0.35878387093544006 1.1482267379760742
+0.8431717157363892 0.0031638462096452713 0.011237679980695248 0.25238531827926636 1.350293517112732
+0.7992390394210815 0.0013693530345335603 0.010144190862774849 0.3616628348827362 1.1564710140228271
+0.7992468476295471 0.0011162001173943281 0.010470988228917122 0.2784631848335266 1.2585325241088867
+0.8131712079048157 0.001401241635903716 0.011012920178472996 0.3692837953567505 1.1380952596664429
+0.8005406856536865 0.0018879659473896027 0.012136919423937798 0.31142017245292664 1.1679540872573853
+0.8747624158859253 0.0013806753559038043 0.011122413910925388 0.2843429148197174 1.4356613159179688
+0.9029581546783447 0.0014587931800633669 0.013197261840105057 0.4342752993106842 1.443468451499939
+0.7602533102035522 0.0015820900443941355 0.016315704211592674 0.31902700662612915 1.0282950401306152
+0.7588459253311157 0.0022813535761088133 0.013568447902798653 0.33395442366600037 1.230805516242981
+0.8402261137962341 0.0020342236384749413 0.013339439406991005 0.28900545835494995 1.3124961853027344
+0.7978499531745911 0.0013131988234817982 0.014255809597671032 0.3479897975921631 1.273046851158142
+0.8712509274482727 0.0017660841112956405 0.012866368517279625 0.3929423689842224 1.1860002279281616
+0.807334303855896 0.0014154233504086733 0.009749134071171284 0.2831491231918335 1.2645183801651
+0.8421046137809753 0.0012993256095796824 0.01074039377272129 0.25033989548683167 1.3850313425064087
+0.8487778902053833 0.0017236449057236314 0.011690723709762096 0.23211731016635895 1.2452564239501953
+0.8130440711975098 0.0014246131759136915 0.013797436840832233 0.36569151282310486 1.137976884841919
+0.8245829343795776 0.0014400901272892952 0.012311029247939587 0.4340997040271759 1.2309402227401733
+0.7634490728378296 0.0015909059438854456 0.012002609670162201 0.37797775864601135 1.343414306640625
+0.8624300360679626 0.0015092574758455157 0.016885889694094658 0.3986849784851074 1.3370585441589355
+0.8335155844688416 0.0014767907559871674 0.019540563225746155 0.46713048219680786 1.0349103212356567
+0.7928241491317749 0.003774474374949932 0.012850509025156498 0.32537731528282166 1.1248656511306763
+0.8109406232833862 0.004375062882900238 0.01374419778585434 0.29462364315986633 1.3660269975662231
+0.8057234287261963 0.0021076002158224583 0.01031474955379963 0.21743856370449066 1.3762067556381226
+0.871938169002533 0.012295442633330822 0.01971455290913582 0.34334221482276917 1.2010818719863892
+0.8115530014038086 0.004122960846871138 0.010966951958835125 0.3909497857093811 1.1687747240066528
+0.8384988903999329 0.0032355336006730795 0.01448874082416296 0.356032133102417 1.3902772665023804
+0.7962056398391724 0.0033027096651494503 0.013075358234345913 0.29384008049964905 1.2225273847579956
+0.78777015209198 0.0026340975891798735 0.012818660587072372 0.317107617855072 1.183811902999878
+0.8084424734115601 0.00686833169311285 0.018671125173568726 0.44010454416275024 1.1055939197540283
+0.8055517673492432 0.001586323487572372 0.01207085233181715 0.25895756483078003 1.2997738122940063
+0.803376317024231 0.004425056744366884 0.015104040503501892 0.27493247389793396 1.1646207571029663
+0.7779552936553955 0.003050954081118107 0.011091678403317928 0.362951397895813 1.0884432792663574
+0.7937660217285156 0.0019965339452028275 0.01381766889244318 0.26454880833625793 1.2073261737823486
+0.8019046783447266 0.0035808999091386795 0.01724414899945259 0.3280189633369446 1.2301712036132812
+0.79725182056427 0.0016061257338151336 0.011978322640061378 0.3537939786911011 1.2045387029647827
+0.8230899572372437 0.0036202736664563417 0.010121353901922703 0.2211632877588272 1.4109013080596924
+0.8017736077308655 0.0018799350364133716 0.011949223466217518 0.28793877363204956 1.2001721858978271
+0.8338784575462341 0.0016922455979511142 0.014281736686825752 0.3775290548801422 1.1217293739318848
+0.7743864059448242 0.0013983346289023757 0.011194972321391106 0.2121613621711731 1.2865973711013794
+0.7900530695915222 0.001343326410278678 0.010931693017482758 0.2749996781349182 1.2255229949951172
+0.8164499402046204 0.0014691484393551946 0.009985932148993015 0.30582448840141296 1.2131154537200928
+0.8119259476661682 0.001869994099251926 0.010050580836832523 0.2944113612174988 1.129643201828003
+0.7457248568534851 0.0016947418916970491 0.009173603728413582 0.35546988248825073 0.9792178869247437
+0.8228278160095215 0.0013158959336578846 0.009110731072723866 0.3010355532169342 1.3785134553909302
+0.7667663097381592 0.0017996751703321934 0.010371020063757896 0.3682135343551636 1.142073154449463
+0.809066891670227 0.001726713264361024 0.011744175106287003 0.3492065668106079 1.2240766286849976
+0.7706704139709473 0.0013714836677536368 0.011756440624594688 0.2865349352359772 1.1739237308502197
+0.7962474822998047 0.0023983377031981945 0.010609653778374195 0.39364296197891235 1.199324131011963
+0.838119387626648 0.0017678739968687296 0.011759859509766102 0.34176626801490784 1.238878846168518
+0.7708641886711121 0.0015520259039476514 0.015131867490708828 0.3453558385372162 1.0868221521377563
+0.757474958896637 0.0014903312548995018 0.016711026430130005 0.26630574464797974 1.1531097888946533
+0.7425395250320435 0.0014754609437659383 0.01340403500944376 0.2940783202648163 1.0824835300445557
+0.7435439825057983 0.0020949114114046097 0.012414596974849701 0.3269610106945038 1.085756540298462
+0.7815203666687012 0.0015113339759409428 0.010606087744235992 0.3123805522918701 1.1377558708190918
+0.7887165546417236 0.001278629177249968 0.010646844282746315 0.3399854004383087 1.1496143341064453
+0.839688241481781 0.002750528510659933 0.009683733806014061 0.32547158002853394 1.3270063400268555
+0.8545882701873779 0.0028332655783742666 0.011906399391591549 0.3193119466304779 1.3383808135986328
+0.8190232515335083 0.0012308666482567787 0.011720214039087296 0.3801184892654419 1.1704192161560059
+0.7839843034744263 0.00213335151784122 0.012878959067165852 0.24898113310337067 1.3760043382644653
+0.7972829937934875 0.002183690667152405 0.013346797786653042 0.32166916131973267 1.2378597259521484
+0.7807274460792542 0.0013398360460996628 0.013961996883153915 0.43935874104499817 1.0498549938201904
+0.745665431022644 0.0011599217541515827 0.011456611566245556 0.3580838143825531 1.1415109634399414
+0.7977756261825562 0.001078349188901484 0.009315671399235725 0.4063875079154968 1.1491632461547852
+0.8383263349533081 0.002532089129090309 0.008131354115903378 0.27972519397735596 1.2910850048065186
+0.6883704662322998 0.005802609957754612 0.008176038973033428 0.3266250491142273 0.9295578598976135
+0.8327978849411011 0.0035366343799978495 0.007761099375784397 0.3124848008155823 1.2823576927185059
+0.7615809440612793 0.00129890616517514 0.009562649764120579 0.2635955810546875 1.1990678310394287
+0.7739999294281006 0.002171917585656047 0.010913551785051823 0.31982749700546265 1.2483762502670288
+0.7460802793502808 0.0031553367152810097 0.013471676036715508 0.3428274989128113 1.170880675315857
+0.7293487787246704 0.003297027200460434 0.016945339739322662 0.3333563506603241 0.9663287997245789
+0.7857531309127808 0.0044010658748447895 0.012053669430315495 0.23026901483535767 1.3047027587890625
+0.7945495843887329 0.002229974139481783 0.010974494740366936 0.33783891797065735 1.241965651512146
+0.8004581332206726 0.001789837027899921 0.01189497485756874 0.4537099003791809 0.9932981729507446
+0.767146646976471 0.002844627480953932 0.008833564817905426 0.24502818286418915 1.2466777563095093
+0.7961668968200684 0.0033906269818544388 0.008356953971087933 0.2310490906238556 1.415889859199524
+0.8387801647186279 0.0028127292171120644 0.01342916302382946 0.4443298280239105 1.0386162996292114
+0.7694172263145447 0.00174017995595932 0.011488275602459908 0.3212776780128479 0.9917318820953369
+0.7786615490913391 0.005515003576874733 0.013332298956811428 0.2818613052368164 1.3520411252975464
+0.7871006727218628 0.0012812688946723938 0.011006316170096397 0.2245744913816452 1.4631130695343018
+0.8397105932235718 0.0060665188357234 0.02271832525730133 0.43212562799453735 0.9955604076385498
+0.7752143144607544 0.0012632622383534908 0.012213571928441525 0.39023464918136597 1.0452808141708374
+0.7143055200576782 0.006084675434976816 0.016528921201825142 0.28831496834754944 1.161577582359314
+0.730324923992157 0.0015668279957026243 0.011768964119255543 0.25715944170951843 1.2108368873596191
+0.7704471945762634 0.01133065763860941 0.013289802707731724 0.3317459523677826 1.0558538436889648
+0.7276421189308167 0.01353343017399311 0.018302347511053085 0.2962823510169983 0.8599514961242676
+0.7827227711677551 0.001324883778579533 0.0087349908426404 0.2574567198753357 1.3302885293960571
+0.7690423727035522 0.006985611282289028 0.01253086794167757 0.1933216005563736 1.3294975757598877
+0.8177518844604492 0.001994333229959011 0.01593560166656971 0.4940827190876007 1.0468956232070923
+0.853776216506958 0.001589559717103839 0.021978680044412613 0.4002068340778351 1.1858071088790894
+0.7333141565322876 0.001833922229707241 0.014602700248360634 0.3000571131706238 1.1248819828033447
+0.7368049025535583 0.0032581109553575516 0.01041931752115488 0.23039917647838593 1.2534726858139038
+0.7034114599227905 0.0011072504566982388 0.010486210696399212 0.26722198724746704 1.0640475749969482
+0.7607307434082031 0.004508954472839832 0.013533391989767551 0.30973362922668457 1.0653024911880493
+0.7061874866485596 0.0031232465989887714 0.011377043090760708 0.3388000428676605 0.9392197728157043
+0.6874368190765381 0.0011385041289031506 0.009604518301784992 0.26296594738960266 1.0843864679336548
+0.702731192111969 0.0012276067864149809 0.009456649422645569 0.3254876732826233 1.1580630540847778
+0.7775306701660156 0.0029630509670823812 0.012086843140423298 0.4225022494792938 1.1549509763717651
+0.7488006353378296 0.003086781594902277 0.016842147335410118 0.4303632378578186 0.9399311542510986
+0.7448853850364685 0.0025676016230136156 0.010698189027607441 0.341485857963562 1.1403299570083618
+0.7274191379547119 0.005846783518791199 0.01092612650245428 0.26500892639160156 1.2434602975845337
+0.7774208784103394 0.0013453755527734756 0.011691421270370483 0.3142072856426239 1.242121696472168
+0.6955479383468628 0.003014633199200034 0.017871703952550888 0.3515973687171936 0.9472256898880005
+0.7760064601898193 0.0012220772914588451 0.010745908133685589 0.4044463038444519 1.0282397270202637
+0.7775273323059082 0.0029536474030464888 0.010969330556690693 0.3034636378288269 1.2712198495864868
+0.7560648918151855 0.0011339265620335937 0.008170544169843197 0.22473624348640442 1.2819876670837402
+0.6755255460739136 0.009057121351361275 0.017074190080165863 0.3231717348098755 0.9474546313285828
+0.7121551036834717 0.00422686105594039 0.015452079474925995 0.348655641078949 0.8923670053482056
+0.676896870136261 0.003910087049007416 0.01212770864367485 0.37302181124687195 0.8737180233001709
+0.7011690139770508 0.0066377706825733185 0.01474473625421524 0.15348397195339203 1.260682463645935
+0.7355504035949707 0.003215935081243515 0.013069340027868748 0.2100490927696228 1.2468551397323608
+0.7785873413085938 0.013036226853728294 0.03578303009271622 0.3462064862251282 0.9297190308570862
+0.7143374085426331 0.001450518611818552 0.012933443300426006 0.3733006715774536 0.9291136860847473
+0.7481463551521301 0.01796196959912777 0.02124965190887451 0.17702481150627136 1.2040715217590332
+0.7982478737831116 0.00357172847725451 0.010914054699242115 0.1892162263393402 1.4590235948562622
+0.7468312978744507 0.015012979507446289 0.0316157341003418 0.4922407865524292 0.9049389362335205
+0.7107776403427124 0.010626336559653282 0.024786874651908875 0.35267943143844604 0.890088677406311
+0.6648107767105103 0.004574976395815611 0.01331239752471447 0.309863805770874 1.0000617504119873
+0.7368533611297607 0.012088608928024769 0.023704830557107925 0.20426328480243683 1.3808292150497437
+0.7109097242355347 0.002488484838977456 0.011119968257844448 0.2995573878288269 1.1128517389297485
+0.7815001010894775 0.018800845369696617 0.03362337499856949 0.46173977851867676 0.8559249043464661
+0.6931371688842773 0.00255604088306427 0.011781306006014347 0.2855652868747711 1.0557975769042969
+0.7989795207977295 0.008845013566315174 0.017681684345006943 0.2933961749076843 1.471011757850647
+0.7232208847999573 0.004334008786827326 0.012719273567199707 0.30387967824935913 1.1632157564163208
+0.7529793381690979 0.0038536626379936934 0.019785061478614807 0.42362239956855774 0.9338582158088684
+0.7262948155403137 0.005260475445538759 0.017226653173565865 0.43862348794937134 0.9208632707595825
+0.6648885607719421 0.0012688999995589256 0.00947241485118866 0.26360955834388733 1.1422845125198364
+0.6633646488189697 0.0027308957651257515 0.009919753298163414 0.22308509051799774 1.1632524728775024
+0.6689644455909729 0.0014013962354511023 0.013774806633591652 0.25937220454216003 0.9826259613037109
+0.7071108818054199 0.0026867035776376724 0.02046709880232811 0.42909321188926697 0.8298208117485046
+0.6622597575187683 0.0016118722269311547 0.010434315539896488 0.2560898959636688 1.087273359298706
+0.673980176448822 0.0035595896188169718 0.010317436419427395 0.21568064391613007 1.1895064115524292
+0.6873634457588196 0.0013200113316997886 0.013539284467697144 0.30591338872909546 0.9954992532730103
+0.7209731936454773 0.0015370786422863603 0.01768314652144909 0.4089638292789459 0.8544839024543762
+0.6699643731117249 0.001285190344788134 0.010913980193436146 0.3648134768009186 1.0061330795288086
+0.7337982654571533 0.0019473341526463628 0.011828041635453701 0.2511979043483734 1.2430058717727661
+0.7360092401504517 0.00135140772908926 0.009503873996436596 0.2966771423816681 1.1454806327819824
+0.7154949903488159 0.005138287786394358 0.011385737918317318 0.3042116165161133 0.9663556814193726
+0.6326080560684204 0.0027184742502868176 0.010055210441350937 0.2228945642709732 0.9444437623023987
+0.7139711380004883 0.0011426556156948209 0.010469326749444008 0.311149537563324 1.1492271423339844
+0.7100996971130371 0.0015212232246994972 0.011477994732558727 0.3804107904434204 1.0808947086334229
+0.6332249045372009 0.0010845967335626483 0.011087235063314438 0.3193210959434509 0.8917467594146729
+0.6989697217941284 0.0012837962713092566 0.010715286247432232 0.337988018989563 1.1797257661819458
+0.7173676490783691 0.001397354295477271 0.008874149061739445 0.3186389207839966 1.0554497241973877
+0.6863187551498413 0.0011073071509599686 0.009519138373434544 0.4111126959323883 1.0542491674423218
+0.6474237442016602 0.001144403126090765 0.0097268745303154 0.45354560017585754 0.8097855448722839
+0.7250284552574158 0.0010465539526194334 0.007971934042870998 0.31139546632766724 1.172694444656372
+0.7255516052246094 0.0022910067345947027 0.008342702873051167 0.20544341206550598 1.1830518245697021
+0.7047302722930908 0.0037365281023085117 0.013516854494810104 0.3167065978050232 0.9655742645263672
+0.7145401239395142 0.0012871114304289222 0.011300325393676758 0.40649527311325073 0.9180693626403809
+0.7227112054824829 0.005691499914973974 0.00868943426758051 0.3317727744579315 1.0693285465240479
+0.7044875621795654 0.0036253645084798336 0.008572456426918507 0.2231060415506363 1.2014884948730469
+0.6622821688652039 0.0020183043088763952 0.01418671477586031 0.35840412974357605 0.8794580101966858
+0.6848524808883667 0.003030146239325404 0.01649906486272812 0.3123398721218109 0.9816902875900269
+0.6440466642379761 0.0013100234791636467 0.008416612632572651 0.22784027457237244 1.0560331344604492
+0.7047592401504517 0.0027741899248212576 0.009799504652619362 0.35386478900909424 1.1202406883239746
+0.66963791847229 0.0011344488011673093 0.00892073754221201 0.3312074840068817 0.94609135389328
+0.6810587644577026 0.002164402510970831 0.010839518159627914 0.2817712724208832 0.9286237955093384
+0.7076109051704407 0.0024704698007553816 0.009398349560797215 0.3038281798362732 1.1086375713348389
+0.708882212638855 0.001298387534916401 0.009550293907523155 0.27724766731262207 1.1758980751037598
+0.6557642221450806 0.0011416901834309101 0.012058948166668415 0.2793331742286682 0.8881343603134155
+0.6126943826675415 0.0016721256542950869 0.010331091471016407 0.27561870217323303 0.8976156115531921
+0.7135452032089233 0.00154692807700485 0.010401838459074497 0.32701367139816284 1.174601674079895
+0.665656566619873 0.0011455266503617167 0.013503859750926495 0.2604423463344574 1.056106686592102
+0.7366831302642822 0.0014359221095219254 0.011391602456569672 0.31820592284202576 1.2037259340286255
+0.6939690113067627 0.0011340799974277616 0.011293704621493816 0.3072962164878845 0.9206767082214355
+0.7024206519126892 0.0012748617446050048 0.01066314522176981 0.3991597890853882 0.8280844688415527
+0.6624963283538818 0.0011188257485628128 0.01007548626512289 0.20758259296417236 1.1514374017715454
+0.6399281024932861 0.004015413578599691 0.00867296103388071 0.17547693848609924 1.0828183889389038
+0.7029967904090881 0.00950993224978447 0.021913960576057434 0.28478869795799255 0.9651790857315063
+0.6456894278526306 0.00128131581004709 0.012840267270803452 0.5137242078781128 0.846096932888031
+0.6783156394958496 0.01288775634020567 0.012941296212375164 0.25319012999534607 1.0627038478851318
+0.6715829372406006 0.009561439044773579 0.011646660976111889 0.18692296743392944 1.0487446784973145
+0.7102094888687134 0.002051125979050994 0.012435168959200382 0.2006291002035141 1.1807026863098145
+0.6247101426124573 0.011300645768642426 0.02272891066968441 0.281913161277771 0.8690866827964783
+0.6579247713088989 0.002029157942160964 0.008595732972025871 0.2707090377807617 1.0096663236618042
+0.7242125868797302 0.0042111375369131565 0.011723380535840988 0.37411537766456604 1.1871936321258545
+0.6852309703826904 0.0024598105810582638 0.010140540078282356 0.31558629870414734 1.0169668197631836
+0.6961726546287537 0.0022444783244282007 0.01057793851941824 0.4119569659233093 0.9301369786262512
+0.6696659922599792 0.003916672430932522 0.011367838829755783 0.27817511558532715 0.9413745999336243
+0.671131432056427 0.0010948169510811567 0.008127294480800629 0.18258048593997955 1.1196463108062744
+0.6314627528190613 0.0013202549889683723 0.008852953091263771 0.269340842962265 0.9819831848144531
+0.6873130202293396 0.0010857084998860955 0.01041335891932249 0.24080149829387665 1.017618179321289
+0.6873871684074402 0.0011856612982228398 0.010208560153841972 0.33627986907958984 0.8097127676010132
+0.7055980563163757 0.0014383570523932576 0.008187251165509224 0.23567545413970947 1.1568185091018677
+0.6755511164665222 0.0010704232845455408 0.008371001109480858 0.18722935020923615 1.2424991130828857
+0.6002792119979858 0.0029854883905500174 0.01220513042062521 0.24592703580856323 0.8975947499275208
+0.6682871580123901 0.0026333096902817488 0.012066656723618507 0.27740734815597534 0.8500942587852478
+0.698694109916687 0.0018807367887347937 0.008786646649241447 0.3170251250267029 0.8840293288230896
+0.6497287750244141 0.00309299910441041 0.008978118188679218 0.21555423736572266 1.0660074949264526
+0.6785869598388672 0.0014807500410825014 0.01062911655753851 0.2191060185432434 1.1564222574234009
+0.6796467304229736 0.006467883475124836 0.025747252628207207 0.35558217763900757 1.0126720666885376
+0.701000452041626 0.0014214320108294487 0.019852520897984505 0.5416061878204346 0.8205257058143616
+0.6413305997848511 0.010597367770969868 0.012832919135689735 0.23506468534469604 0.9261528253555298
+0.6630938053131104 0.00910319946706295 0.013376601971685886 0.2512706220149994 1.1971906423568726
+0.5912479758262634 0.0032270641531795263 0.014541604556143284 0.25503483414649963 0.9547145366668701
+0.7057960033416748 0.014477815479040146 0.02611992135643959 0.3315737843513489 0.8925693035125732
+0.6517613530158997 0.0018219486810266972 0.009762457571923733 0.2101263850927353 1.0146735906600952
+0.6698386669158936 0.00579047529026866 0.016492322087287903 0.2513541579246521 1.0457007884979248
+0.6845617294311523 0.0031102471984922886 0.013464048504829407 0.3214496374130249 0.9632224440574646
+0.6681456565856934 0.00369078665971756 0.010995055548846722 0.415677547454834 0.9903425574302673
+0.6514320969581604 0.009669085033237934 0.016609953716397285 0.23063874244689941 1.070285677909851
+0.6058160066604614 0.0021036022808402777 0.010726157575845718 0.3020743727684021 0.8996824026107788
+0.6560643911361694 0.004342993255704641 0.012452098540961742 0.23054198920726776 1.0355886220932007
+0.5806335210800171 0.004880113992840052 0.012188472785055637 0.24710610508918762 0.7344155311584473
+0.6136050224304199 0.0009484650800004601 0.033578816801309586 0.3146684467792511 0.8520517349243164
+0.6550614833831787 0.004163133446127176 0.011022194288671017 0.25949162244796753 1.0406492948532104
+0.6191328167915344 0.004171549342572689 0.009959862567484379 0.22576409578323364 1.0536466836929321
+0.6519684195518494 0.0011809399584308267 0.008959519676864147 0.22051841020584106 1.0350425243377686
+0.6283367276191711 0.002102043479681015 0.009741587564349174 0.3348528742790222 0.7588825225830078
+0.68300861120224 0.003903054865077138 0.008385141380131245 0.2154078185558319 0.9951320886611938
+0.6120598316192627 0.001341388444416225 0.009693223051726818 0.24638885259628296 1.0154945850372314
+0.6509910821914673 0.0020468845032155514 0.014747721143066883 0.30157992243766785 1.0998592376708984
+0.6442064046859741 0.001814301242120564 0.01758234016597271 0.2875243127346039 0.8457469344139099
+0.6733372807502747 0.0024322227109223604 0.01069115474820137 0.32294100522994995 0.9655081033706665
+0.6604498624801636 0.005722575820982456 0.010905683040618896 0.20649172365665436 1.0818833112716675
+0.6204640865325928 0.0009889333741739392 0.007662871386855841 0.1853804737329483 0.9172659516334534
+0.6287599802017212 0.006232894957065582 0.010104349814355373 0.24285942316055298 0.9649905562400818
+0.6637967824935913 0.007733535021543503 0.010207951068878174 0.2710855007171631 1.1100685596466064
+0.6021167635917664 0.0013023456558585167 0.008343874476850033 0.3237573206424713 0.8560200333595276
+0.6046200394630432 0.003303548553958535 0.011114995926618576 0.2573586404323578 0.9072754979133606
+0.6210378408432007 0.0027948766946792603 0.009832552634179592 0.292673796415329 0.9395859241485596
+0.6877794861793518 0.0014096886152401567 0.009400570765137672 0.30001187324523926 1.0206444263458252
+0.6194762587547302 0.003999102395027876 0.012514453381299973 0.23353999853134155 0.989738404750824
+0.5833878517150879 0.0014904845738783479 0.00860985554754734 0.30099785327911377 0.7978940606117249
+0.6509503722190857 0.002333794254809618 0.013478195294737816 0.24555452167987823 0.9418787360191345
+0.5865160822868347 0.0021232443396002054 0.007406008895486593 0.19975806772708893 0.9501368999481201
+0.6269460320472717 0.0020596995018422604 0.011624329723417759 0.22932785749435425 0.9220640063285828
+0.6449941992759705 0.003755055833607912 0.01935001090168953 0.24171167612075806 0.9883286952972412
+0.6543258428573608 0.0012977153528481722 0.014004530385136604 0.3627999424934387 0.9091910123825073
+0.5962564945220947 0.007141549605876207 0.010577132925391197 0.24920706450939178 0.9175602793693542
+0.5765858888626099 0.005340190138667822 0.01001595240086317 0.22534222900867462 0.8884676098823547
+0.6428625583648682 0.00122581678442657 0.00814421009272337 0.24851827323436737 0.9766395092010498
+0.6820006966590881 0.006685267202556133 0.01040020864456892 0.2690068483352661 1.1370195150375366
+0.6634711027145386 0.005395456682890654 0.010721714235842228 0.2786504030227661 0.9070071578025818
+0.599478006362915 0.0009728160803206265 0.00960820447653532 0.303954154253006 0.728125810623169
+0.6598824858665466 0.003488146234303713 0.015217097476124763 0.2445612996816635 1.2292449474334717
+0.6666401624679565 0.0014041791437193751 0.008437830954790115 0.3514452874660492 1.0898993015289307
+0.6202415227890015 0.007367835380136967 0.022261934354901314 0.4064832031726837 0.8392665982246399
+0.627206027507782 0.0011213290272280574 0.012318712659180164 0.3608904778957367 0.8504062294960022
+0.6664971113204956 0.007819078862667084 0.013874072581529617 0.22238336503505707 1.2190918922424316
+0.6607227921485901 0.0036967997439205647 0.009995298460125923 0.28076452016830444 1.0392347574234009
+0.6127805709838867 0.003563936799764633 0.011900578625500202 0.29077666997909546 0.7890543341636658
+0.5837047100067139 0.006378915160894394 0.011365504004061222 0.31435340642929077 0.7568070888519287
+0.674202561378479 0.0011808081762865186 0.00615977356210351 0.2027687430381775 1.1693795919418335
+0.6349570751190186 0.0012479209108278155 0.006824570707976818 0.200816810131073 1.0491331815719604
+0.5769265294075012 0.001074649509973824 0.008150175213813782 0.2637510597705841 0.7986965179443359
+0.6555874347686768 0.0012936405837535858 0.011939825490117073 0.3703126311302185 1.081670880317688
+0.6040211915969849 0.0014419378712773323 0.010799625888466835 0.2469271570444107 0.9799614548683167
+0.6520620584487915 0.0036031468771398067 0.008798767812550068 0.21436849236488342 1.0324033498764038
+0.6225029826164246 0.002448274753987789 0.008516901172697544 0.3041876554489136 0.9225970506668091
+0.6291906833648682 0.0010452962014824152 0.00871450174599886 0.2565619647502899 0.967940092086792
+0.6366934776306152 0.002576639875769615 0.007745940703898668 0.20964550971984863 0.8697614669799805
+0.6448242664337158 0.0024833264760673046 0.00675581069663167 0.2491995096206665 1.0218133926391602
+0.6211211681365967 0.0020437357015907764 0.007910181768238544 0.25007426738739014 0.9663841724395752
+0.5920979380607605 0.0013425600482150912 0.008946025744080544 0.1933193802833557 0.9180220365524292
+0.6095212697982788 0.0011288796085864305 0.008987118490040302 0.25220417976379395 0.8518081903457642
+0.6512099504470825 0.002327561378479004 0.010019704699516296 0.2896636426448822 0.8432323336601257
+0.6543878316879272 0.002843148075044155 0.008491798304021358 0.24883097410202026 1.1651437282562256
+0.6287339925765991 0.0010897662723436952 0.03587450087070465 0.26395729184150696 1.0580750703811646
+0.6504688858985901 0.001478869467973709 0.024160034954547882 0.3243662714958191 0.7878902554512024
+0.6695401668548584 0.0008767988765612245 0.007425805553793907 0.32297661900520325 0.9832428097724915
+0.6196022629737854 0.0009961762698367238 0.007617373950779438 0.21197929978370667 1.00998055934906
+0.6348196864128113 0.0017740089679136872 0.0064168209210038185 0.21855345368385315 1.1007364988327026
+0.5920163989067078 0.006011504214257002 0.013431592844426632 0.3072604835033417 0.9400058388710022
+0.6755110025405884 0.0014136137906461954 0.011895076371729374 0.354144424200058 0.8489648103713989
+0.5640225410461426 0.005224265158176422 0.012629657983779907 0.30762210488319397 0.6950331330299377
+0.6673687696456909 0.00759726669639349 0.01375812292098999 0.18754152953624725 1.2705100774765015
+0.6244409084320068 0.0017670044908300042 0.009246999397873878 0.18909375369548798 0.9964756965637207
+0.6406119465827942 0.011452843435108662 0.02236742340028286 0.32394590973854065 0.8414627909660339
+0.6339374780654907 0.001691170735284686 0.008603723719716072 0.28229641914367676 0.8481914401054382
+0.6550008654594421 0.0062922644428908825 0.014040823094546795 0.22231777012348175 1.1038274765014648
+0.6646418571472168 0.0024650441482663155 0.008601713925600052 0.27220603823661804 1.1291381120681763
+0.6116758584976196 0.0062520550563931465 0.018081920221447945 0.3244520425796509 0.8895716071128845
+0.5966089963912964 0.007347988896071911 0.024429965764284134 0.3450669050216675 0.8007300496101379
+0.5657439827919006 0.00291219144128263 0.008410345762968063 0.18304671347141266 0.7623265981674194
+0.6023731231689453 0.013679162599146366 0.017486393451690674 0.3133845031261444 0.9085643291473389
+0.5904020071029663 0.0013276428217068315 0.0069846948608756065 0.2060845047235489 1.0382381677627563
+0.6278192400932312 0.013053594157099724 0.02075948566198349 0.24454906582832336 0.8566851019859314
+0.6385711431503296 0.010794622823596 0.020718323066830635 0.2662399411201477 0.8043386936187744
+0.6102945804595947 0.0017597185214981437 0.007709468714892864 0.16879287362098694 0.9588382244110107
+0.6014348268508911 0.010204408317804337 0.014466610737144947 0.20008547604084015 1.0038427114486694
+0.6381391286849976 0.0015243426896631718 0.007349784020334482 0.2474256455898285 0.9889337420463562
+0.6027135848999023 0.005233195144683123 0.018859200179576874 0.37922465801239014 0.7580824494361877
+0.6483604907989502 0.0030266372486948967 0.013389154337346554 0.3693782091140747 0.8934231996536255
+0.6115455031394958 0.0022062284406274557 0.007755814120173454 0.1843452751636505 0.984408438205719
+0.6470574736595154 0.002864431357011199 0.009593289345502853 0.2109789252281189 1.1501246690750122
+0.564699649810791 0.0020940841641277075 0.00978765171021223 0.3024391531944275 0.8389268517494202
+0.6112203598022461 0.004814459942281246 0.016921989619731903 0.2462400645017624 0.7778354287147522
+0.5806071162223816 0.0010017630411311984 0.009236374869942665 0.23967967927455902 0.816459596157074
+0.6363288760185242 0.005348813254386187 0.010630348697304726 0.20832796394824982 1.1046383380889893
+0.6204290390014648 0.0020106746815145016 0.007238632068037987 0.1902911216020584 1.0536264181137085
+0.6367231607437134 0.0029534839559346437 0.01315245870500803 0.3256777822971344 0.8875450491905212
+0.6065154075622559 0.0033122808672487736 0.01367876585572958 0.31898197531700134 0.7937415242195129
+0.582554817199707 0.001194523647427559 0.007141938898712397 0.18450425565242767 0.9116330146789551
+0.6086462140083313 0.0026777826715260744 0.009750219993293285 0.2557189166545868 1.0520472526550293
+0.5579041838645935 0.0011895596981048584 0.007809760514646769 0.2098519653081894 0.8948476314544678
+0.6361044645309448 0.004777492955327034 0.015037855133414268 0.25649967789649963 0.8110824823379517
+0.6814529299736023 0.0016800769371911883 0.011148909106850624 0.399058073759079 0.9946289658546448
+0.575928807258606 0.00263788178563118 0.008330275304615498 0.1667645424604416 0.9201916456222534
+0.6027486324310303 0.003282708115875721 0.008461314253509045 0.228971928358078 0.9888045191764832
+0.5699206590652466 0.0011827758280560374 0.008049875497817993 0.33283451199531555 0.8946889042854309
+0.6208285093307495 0.004233190789818764 0.015295504592359066 0.2651101052761078 0.8597930669784546
+0.5612775087356567 0.0012478948337957263 0.009050245396792889 0.2512233257293701 0.7480809092521667
+0.6304036378860474 0.0034234679769724607 0.009133309125900269 0.18180406093597412 0.9926854372024536
+0.5409549474716187 0.0022401656024158 0.007301073055714369 0.24433577060699463 0.9023473858833313
+0.6341080069541931 0.0017161653377115726 0.010326745919883251 0.2629709541797638 1.0757054090499878
+0.574053168296814 0.0035805024672299623 0.01550269965082407 0.3557387590408325 0.6430301070213318
+0.5944669246673584 0.0011101366253569722 0.007098070811480284 0.32184481620788574 0.7862202525138855
+0.5319759845733643 0.004338660277426243 0.01236083172261715 0.23415730893611908 0.8931224942207336
+0.6553504467010498 0.0008314242004416883 0.00586344301700592 0.2865683138370514 1.078943133354187
+0.6343352794647217 0.008756815455853939 0.014834470115602016 0.33046838641166687 0.861448347568512
+0.5750882029533386 0.005668337456882 0.01645256206393242 0.30195754766464233 0.7952251434326172
+0.5788866281509399 0.0023021160159260035 0.007620890159159899 0.21132445335388184 0.9062602519989014
+0.5799643993377686 0.01050910446792841 0.012768536806106567 0.24355052411556244 0.8409299850463867
+0.6153682470321655 0.002918397542089224 0.008706587366759777 0.275510311126709 1.0443397760391235
+0.580915093421936 0.004520075395703316 0.014083814807236195 0.19713644683361053 0.9521878361701965
+0.5779366493225098 0.007900341413915157 0.017626848071813583 0.3142835795879364 0.7687413096427917
+0.5603711605072021 0.0008608726784586906 0.007078849244862795 0.20478710532188416 0.8139731884002686
+0.5356489419937134 0.004716170486062765 0.012900789268314838 0.23532582819461823 0.8054733872413635
+0.5829275250434875 0.0013078757328912616 0.007203069981187582 0.16339169442653656 1.0101444721221924
+0.5528968572616577 0.006754937116056681 0.013134848326444626 0.30942627787590027 0.8097187876701355
+0.5499745011329651 0.008621878921985626 0.02299213781952858 0.3693782091140747 0.7294020652770996
+0.5679960250854492 0.0013059412594884634 0.009082618169486523 0.24141111969947815 0.8871358036994934
+0.5467170476913452 0.010294211097061634 0.0075630140490829945 0.1639786809682846 0.8981862664222717
+0.583716094493866 0.007114111445844173 0.008241643197834492 0.29802238941192627 0.8953153491020203
+0.5203300714492798 0.0012100031599402428 0.013185154646635056 0.2985410988330841 0.7241254448890686
+0.5825397372245789 0.0021585423965007067 0.01225263625383377 0.31913259625434875 0.7479796409606934
+0.5886530876159668 0.0010554754408076406 0.0060312882997095585 0.14441430568695068 1.0028245449066162
+0.5851801633834839 0.0009422664879821241 0.0060814726166427135 0.2359970211982727 0.9306648373603821
+0.5518983602523804 0.0013389167143031955 0.006841327529400587 0.1912374347448349 0.821953296661377
+0.5563392639160156 0.001667147851549089 0.007862495258450508 0.29564785957336426 0.6734204888343811
+0.5973297357559204 0.0008316677995026112 0.007119305431842804 0.19727103412151337 0.9068447947502136
+0.5259276628494263 0.0011290867114439607 0.008417848497629166 0.24669669568538666 0.9145501255989075
+0.5521855354309082 0.0008368490962311625 0.007257586345076561 0.29206162691116333 0.8869104385375977
+0.5685075521469116 0.002088971668854356 0.009258241392672062 0.23241621255874634 0.7895500063896179
+0.540971040725708 0.0009624917875044048 0.008388357236981392 0.24099507927894592 0.6971367597579956
+0.5669121742248535 0.001991753466427326 0.0075117251835763454 0.17126582562923431 0.8678876161575317
+0.5769084095954895 0.001245003310032189 0.00652722455561161 0.1796470582485199 0.9843317866325378
+0.566279947757721 0.00265316478908062 0.00771125964820385 0.2734629213809967 0.9293200969696045
+0.5414159297943115 0.004914373159408569 0.015631267800927162 0.3453535735607147 0.6743471622467041
+0.5409266948699951 0.0009583669598214328 0.0073321214877069 0.225369393825531 0.7566177248954773
+0.5064218044281006 0.004857463296502829 0.00698331231251359 0.22075720131397247 0.8507241606712341
+0.5806394815444946 0.0022875445429235697 0.007252247538417578 0.22074195742607117 0.9534627795219421
+0.596484899520874 0.0017228289507329464 0.016671771183609962 0.3080858588218689 0.859622597694397
+0.5650789737701416 0.0015090794768184423 0.01486379373818636 0.24870197474956512 0.7713867425918579
+0.5442870855331421 0.002254609251394868 0.007530102506279945 0.20479710400104523 0.8109661340713501
+0.5314152240753174 0.0035481397062540054 0.009523024782538414 0.21795283257961273 0.8900337815284729
+0.5439499020576477 0.0012905561598017812 0.016121139749884605 0.1862151026725769 0.9515467882156372
+0.5752958059310913 0.00882084108889103 0.01441959384828806 0.30324921011924744 0.6467620730400085
+0.578612208366394 0.0032247111666947603 0.00902671180665493 0.42988908290863037 0.7487800717353821
+0.5146468877792358 0.002392408438026905 0.009142894297838211 0.2023979127407074 0.9242107272148132
+0.557770848274231 0.0037885953206568956 0.009493998251855373 0.1456524282693863 1.0260834693908691
+0.5768649578094482 0.0009566014050506055 0.009512732736766338 0.2565256357192993 0.7878630757331848
+0.589203417301178 0.0021598003804683685 0.011612746864557266 0.3474563658237457 0.7498514652252197
+0.5753374099731445 0.0008436134667135775 0.006619553547352552 0.2539622485637665 0.8720323443412781
+0.563667356967926 0.0013201296096667647 0.007939320057630539 0.21886427700519562 0.9804173111915588
+0.5680645108222961 0.0013542743399739265 0.005337587092071772 0.1282082051038742 0.9839159846305847
+0.5457258820533752 0.005509554408490658 0.012464110739529133 0.3319498300552368 0.6644082069396973
+0.5674055814743042 0.0015411855420097709 0.009220870211720467 0.28839462995529175 0.7476683259010315
+0.5906074047088623 0.0035917097702622414 0.0071301935240626335 0.20556941628456116 0.9404321312904358
+0.5929552316665649 0.004522968549281359 0.006998504512012005 0.20527642965316772 1.0501220226287842
+0.5304490327835083 0.0010323983151465654 0.010745852254331112 0.24847978353500366 0.7446313500404358
+0.5045570731163025 0.002734503475949168 0.015098385512828827 0.24025429785251617 0.7088225483894348
+0.5611063241958618 0.0009356292430311441 0.008296513929963112 0.23629820346832275 0.8953869938850403
+0.564712405204773 0.002072135917842388 0.006528542842715979 0.21043427288532257 0.9943848848342896
+0.6041325330734253 0.0012617032043635845 0.006937406491488218 0.2106771618127823 0.8931520581245422
+0.5512865781784058 0.0010992102324962616 0.008134006522595882 0.2784532606601715 0.7226155996322632
+0.5676004886627197 0.0017282404005527496 0.0062113874591887 0.31261616945266724 0.7931832075119019
+0.5241379737854004 0.0013168794102966785 0.006540270987898111 0.21471896767616272 0.9447513222694397
+0.5407570600509644 0.0010832063853740692 0.007772402372211218 0.2709368169307709 0.8629182577133179
+0.5565921068191528 0.0009467512718401849 0.008509676903486252 0.1982826590538025 0.7684630155563354
+0.5156737565994263 0.002096630400046706 0.00697809923440218 0.2154335230588913 0.7716516256332397
+0.5262911915779114 0.002098335884511471 0.0066704838536679745 0.22958378493785858 0.7809070348739624
+0.5597944259643555 0.0008752762223593891 0.006354962941259146 0.18353964388370514 0.9823623299598694
+0.5212709307670593 0.0029210925567895174 0.010752332396805286 0.17958705127239227 0.795875072479248
+0.565170407295227 0.0017520198598504066 0.009132846258580685 0.18025203049182892 0.8881245851516724
+0.6252115964889526 0.0013178953668102622 0.006971016526222229 0.29885393381118774 0.8629202842712402
+0.5270929336547852 0.0031614424660801888 0.009270688518881798 0.31309592723846436 0.7441552877426147
+0.5097391605377197 0.0009661880903877318 0.0071489824913442135 0.273326575756073 0.7336933612823486
+0.5673655271530151 0.0030046000611037016 0.007714651059359312 0.16197460889816284 0.9958121180534363
+0.5061086416244507 0.005460082553327084 0.01112235989421606 0.3079473376274109 0.8097292184829712
+0.5186704397201538 0.00102539814542979 0.008104916661977768 0.2222120761871338 0.7046389579772949
+0.5417415499687195 0.0037029271479696035 0.008872650563716888 0.2832912802696228 0.7958605289459229
+0.5051677227020264 0.003039747476577759 0.008796874433755875 0.1333722174167633 0.8195461630821228
+0.5462812185287476 0.0017435085028409958 0.007060538046061993 0.18639294803142548 0.8758459687232971
+0.6013286113739014 0.00700259767472744 0.01676390878856182 0.3670854866504669 0.8456767797470093
+0.5337436199188232 0.0015954343834891915 0.010680770501494408 0.2828328609466553 0.6722241044044495
+0.5348972082138062 0.0052874344401061535 0.007375027518719435 0.21664641797542572 0.9042351841926575
+0.5254679322242737 0.005702228285372257 0.007931185886263847 0.15060964226722717 0.9482132792472839
+0.5517773032188416 0.0010116322664543986 0.009252838790416718 0.2786235213279724 0.8493578433990479
+0.5701507329940796 0.003904053708538413 0.015296896919608116 0.45537659525871277 0.7095529437065125
+0.502131462097168 0.0010028415126726031 0.006324457935988903 0.20311762392520905 0.6902198791503906
+0.5162074565887451 0.0018108466174453497 0.007262976374477148 0.19841471314430237 0.9127283692359924
+0.5213631391525269 0.0008242126787081361 0.005554430652409792 0.19626417756080627 0.829563319683075
+0.49276620149612427 0.003917154856026173 0.011648798361420631 0.17382603883743286 0.7075635194778442
+0.557860791683197 0.002832127036526799 0.013381697237491608 0.19615377485752106 0.740412712097168
+0.563413143157959 0.0016579730436205864 0.007323533296585083 0.2530095875263214 0.9693640470504761
+0.500590443611145 0.005386468023061752 0.00843183882534504 0.17604298889636993 0.8342151641845703
+0.5345258116722107 0.0028719597030431032 0.009841050952672958 0.2553044259548187 0.7926786541938782
+0.4756607115268707 0.0009958390146493912 0.009382732212543488 0.2753627896308899 0.6147152185440063
+0.49470680952072144 0.0015567988157272339 0.007142792921513319 0.19205139577388763 0.7355289459228516
+0.5384889841079712 0.0015339009696617723 0.005504470318555832 0.15416330099105835 0.8369232416152954
+0.5747548341751099 0.0017207551281899214 0.006107202731072903 0.1899004727602005 0.9692280292510986
+0.5588374137878418 0.0015842878492549062 0.00850837491452694 0.2545757293701172 0.7925829887390137
+0.5256282687187195 0.0009608964901417494 0.008421079255640507 0.33734452724456787 0.6023041605949402
+0.5314096212387085 0.0036321706138551235 0.008624549955129623 0.1533152461051941 0.8737915754318237
+0.5444780588150024 0.0017114453949034214 0.007511555682867765 0.1408800333738327 1.0037176609039307
+0.6118003129959106 0.00237596919760108 0.008572963997721672 0.28879696130752563 0.919086754322052
+0.5315595865249634 0.004764995072036982 0.01178294699639082 0.34183552861213684 0.6642671823501587
+0.5402306914329529 0.000745185767300427 0.005846831947565079 0.21635395288467407 0.821296215057373
+0.5741156339645386 0.0024397142697125673 0.010424482636153698 0.21819540858268738 0.9168753027915955
+0.5685369968414307 0.0007808200316503644 0.00656471773982048 0.23402753472328186 0.9061410427093506
+0.5395288467407227 0.0035059929359704256 0.00908276904374361 0.18697582185268402 0.715398371219635
+0.5231799483299255 0.002828507451340556 0.010938749648630619 0.2966611087322235 0.7127745747566223
+0.5468047857284546 0.0010059045162051916 0.005842393264174461 0.16120831668376923 0.9250776767730713
+0.5066219568252563 0.0025305384770035744 0.006042013876140118 0.17613936960697174 0.8847292065620422
+0.5482131242752075 0.0010557544883340597 0.006638516206294298 0.2535366117954254 0.8026571273803711
+0.5898752808570862 0.001090922742150724 0.010731000453233719 0.2485869973897934 0.7340509295463562
+0.5115278363227844 0.0009532441035844386 0.006692926865071058 0.25134336948394775 0.6504339575767517
+0.5730191469192505 0.0016657032538205385 0.004983488004654646 0.21942119300365448 1.0291329622268677
+0.5165632963180542 0.0009982931660488248 0.005933516658842564 0.22656658291816711 0.9659519791603088
+0.5645551085472107 0.003244801890105009 0.013987251557409763 0.35046151280403137 0.7315282821655273
+0.5281928777694702 0.0009336845832876861 0.008670509792864323 0.3629915416240692 0.49767357110977173
+0.5855140686035156 0.005042745731770992 0.008621624670922756 0.19180825352668762 1.0299079418182373
+0.5260839462280273 0.001826689695008099 0.006095973774790764 0.1690341681241989 0.9437707662582397
+0.5213824510574341 0.005060737486928701 0.013183333911001682 0.2738635241985321 0.8288967609405518
+0.5190023183822632 0.0069579207338392735 0.023381009697914124 0.3496704399585724 0.6173680424690247
+0.502785861492157 0.0015577635494992137 0.007527241483330727 0.1580873727798462 0.692627489566803
+0.5627894401550293 0.011288261972367764 0.013051805086433887 0.1621076613664627 0.9510571956634521
+0.5779435634613037 0.00317537784576416 0.007576273288577795 0.26674532890319824 0.9240569472312927
+0.5173447728157043 0.0038651616778224707 0.013690714724361897 0.24590808153152466 0.7023951411247253
+0.4955238699913025 0.006437760312110186 0.01247659046202898 0.23623929917812347 0.719595193862915
+0.5154111385345459 0.0008357534534297884 0.007062351796776056 0.28780630230903625 0.7090986967086792
+0.5475799441337585 0.002039910526946187 0.013028590939939022 0.1582873910665512 0.9980160593986511
+0.49714717268943787 0.0008153783855959773 0.007085843943059444 0.23193584382534027 0.7131538987159729
+0.49127471446990967 0.004052083473652601 0.009693106636404991 0.23300914466381073 0.695970356464386
+0.5472708344459534 0.002534568775445223 0.009360674768686295 0.26022106409072876 0.8220299482345581
+0.49115464091300964 0.0012586837401613593 0.006820484530180693 0.22901040315628052 0.7215952277183533
+0.5147364735603333 0.0031448164954781532 0.007851817645132542 0.17417918145656586 0.8579376935958862
+0.4827832579612732 0.0011898897355422378 0.0070521412417292595 0.23238328099250793 0.7370695471763611
+0.5199841260910034 0.0012727960711345077 0.010178850032389164 0.21825188398361206 0.7337210178375244
+0.4782258868217468 0.0011086083250120282 0.008465361781418324 0.18740816414356232 0.7202025651931763
+0.49507707357406616 0.0010846334043890238 0.005267253145575523 0.16934944689273834 0.8788295388221741
+0.525898277759552 0.001174055621959269 0.005444320850074291 0.23092438280582428 0.8024521470069885
+0.5430920720100403 0.0007876339950598776 0.006176460534334183 0.34233662486076355 0.681520402431488
+0.47884172201156616 0.0009621918434277177 0.006440795958042145 0.18400059640407562 0.704879105091095
+0.493574321269989 0.0008209381485357881 0.004864330403506756 0.2254946231842041 0.8140129446983337
+0.4847490191459656 0.0009239953360520303 0.005747813731431961 0.1864251345396042 0.7730098366737366
+0.49708423018455505 0.0009074495173990726 0.007489460986107588 0.2611905336380005 0.6809148788452148
+0.4639478623867035 0.0010423229541629553 0.007332332897931337 0.15429890155792236 0.6445390582084656
+0.49014532566070557 0.0019507200922816992 0.00557706318795681 0.15452851355075836 0.7734361290931702
+0.5174160599708557 0.0012475393014028668 0.006080745253711939 0.16201373934745789 0.81926429271698
+0.5228357315063477 0.0009611058630980551 0.00735953263938427 0.2156829535961151 0.7834831476211548
+0.4778749942779541 0.001057894784025848 0.007437646854668856 0.19306743144989014 0.6974846720695496
+0.5249320268630981 0.0007721612346358597 0.005942866671830416 0.2867347300052643 0.7982074618339539
+0.48106619715690613 0.0007654348155483603 0.006090397480875254 0.16804665327072144 0.7736316919326782
+0.4688223600387573 0.0012725883862003684 0.006032741628587246 0.24750587344169617 0.7527937293052673
+0.5022585391998291 0.0020867527928203344 0.007336774375289679 0.17710120975971222 0.7587610483169556
+0.49728089570999146 0.0010142645332962275 0.007778870407491922 0.23980967700481415 0.7216024398803711
+0.4745916426181793 0.0012855774257332087 0.007243965286761522 0.2009991854429245 0.6774366497993469
+0.5355476140975952 0.0019663823768496513 0.00802823156118393 0.2659943997859955 0.8982519507408142
+0.4699764549732208 0.0008366825059056282 0.007004145532846451 0.21788929402828217 0.6544639468193054
+0.4598377048969269 0.0015453362138941884 0.007193263154476881 0.25595012307167053 0.6850556135177612
+0.5036166906356812 0.0011321118799969554 0.005603731609880924 0.19434230029582977 0.7733917832374573
+0.48398905992507935 0.0008185648475773633 0.005501690786331892 0.16364158689975739 0.763350248336792
+0.496872216463089 0.0007363377953879535 0.006083081476390362 0.26406702399253845 0.6811496615409851
+0.46311599016189575 0.0008267617668025196 0.006037089508026838 0.21560783684253693 0.7986249923706055
+0.5073238611221313 0.0008141687139868736 0.006673259660601616 0.17514020204544067 0.739499568939209
+0.5095674991607666 0.000901196151971817 0.00626360485330224 0.25653132796287537 0.6490381956100464
+0.5136709213256836 0.0009529921808280051 0.005045901518315077 0.1265455186367035 0.8448100090026855
+0.5491091012954712 0.0009894651593640447 0.004278948530554771 0.25178423523902893 0.9084256887435913
+0.46953362226486206 0.003612814238294959 0.008541565388441086 0.22483250498771667 0.7200849652290344
+0.4680880308151245 0.0018747120629996061 0.010948702692985535 0.2461474984884262 0.6054183840751648
+0.5586756467819214 0.0023905213456600904 0.005828074179589748 0.25589022040367126 0.799949586391449
+0.49196767807006836 0.004704405553638935 0.0065996767953038216 0.17487674951553345 0.8436238765716553
+0.5475641489028931 0.001289229141548276 0.007516508921980858 0.23796331882476807 0.9700631499290466
+0.5169028639793396 0.0021753984037786722 0.017615489661693573 0.2673972249031067 0.7757568955421448
+0.5025489926338196 0.0009218254708684981 0.010015169158577919 0.31187739968299866 0.6326962113380432
+0.5484866499900818 0.0029711059760302305 0.00836756732314825 0.2501302659511566 0.7919797897338867
+0.4872688055038452 0.001681232824921608 0.009105871431529522 0.1669255495071411 0.7769342064857483
+0.5216150879859924 0.003874759655445814 0.004773884546011686 0.17494219541549683 0.8506420254707336
+0.5095841288566589 0.012147607281804085 0.01364991720765829 0.27666255831718445 0.8644007444381714
+0.5063281655311584 0.0028672327753156424 0.011051127687096596 0.24333807826042175 0.6969410181045532
+0.47499993443489075 0.006293838378041983 0.00972029659897089 0.2664041221141815 0.5821365118026733
+0.4941364526748657 0.013311207294464111 0.017556436359882355 0.19910143315792084 0.7797845005989075
+0.49485278129577637 0.0008413963369093835 0.0053519899956882 0.17861934006214142 0.847953736782074
+0.5177962779998779 0.013580851256847382 0.02033558301627636 0.33681488037109375 0.8994491696357727
+0.4520183801651001 0.010547256097197533 0.020048733800649643 0.3325119912624359 0.5492911338806152
+0.4712396264076233 0.002077401615679264 0.0058258408680558205 0.21355237066745758 0.5480097532272339
+0.5486184358596802 0.013886314816772938 0.020828774198889732 0.11896064132452011 0.9769544005393982
+0.5092920660972595 0.0010628405725583434 0.00547796068713069 0.14542517066001892 0.9054144024848938
+0.4840725362300873 0.011951607652008533 0.03149799257516861 0.2872231602668762 0.6895146369934082
+0.4866010844707489 0.006448219995945692 0.02449141815304756 0.29697275161743164 0.6347904205322266
+0.48992419242858887 0.0051676747389137745 0.0072266417555511 0.16371460258960724 0.8245095610618591
+0.476931631565094 0.014203490689396858 0.01801428571343422 0.12190931290388107 0.7413856983184814
+0.4604817032814026 0.0008555633248761296 0.005645057186484337 0.14532947540283203 0.7519883513450623
+0.49402058124542236 0.012387825176119804 0.021764425560832024 0.2996123731136322 0.6225931644439697
+0.4380866289138794 0.0091248182579875 0.016534194350242615 0.23590289056301117 0.6052095293998718
+0.4576042592525482 0.0013651831541210413 0.0076032658107578754 0.18636921048164368 0.7533383369445801
+0.4856141209602356 0.00775785930454731 0.015376977622509003 0.13967390358448029 0.7405175566673279
+0.47295477986335754 0.0013789617223665118 0.006382989697158337 0.19618730247020721 0.6523810029029846
+0.5048075914382935 0.0046187532134354115 0.01183148380368948 0.2984234690666199 0.7824589610099792
+0.4869205355644226 0.006136718206107616 0.013202953152358532 0.25352802872657776 0.7240977883338928
+0.504144012928009 0.0007393805426545441 0.005662006326019764 0.26453834772109985 0.7995771765708923
+0.4782702922821045 0.004626747686415911 0.008962835185229778 0.16230159997940063 0.8019737601280212
+0.4622536897659302 0.0025760596618056297 0.006244350224733353 0.19158075749874115 0.6450080871582031
+0.4693972170352936 0.0011135560926049948 0.005227627232670784 0.2403329461812973 0.7295299768447876
+0.4872197210788727 0.004069734830409288 0.009142269380390644 0.2276424765586853 0.7037047743797302
+0.42937010526657104 0.0018348037265241146 0.005627671722322702 0.21915774047374725 0.6576657891273499
+0.43556588888168335 0.0009745100978761911 0.004804776981472969 0.18548405170440674 0.7211376428604126
+0.47492116689682007 0.0018085307674482465 0.005222374107688665 0.17324046790599823 0.6264954209327698
+0.5006510019302368 0.0009200929780490696 0.0057566799223423 0.16126780211925507 0.7601082921028137
+0.45509806275367737 0.0010493663139641285 0.009631057269871235 0.1895526498556137 0.6872947216033936
+0.505086362361908 0.0008977064280770719 0.00818242784589529 0.20979999005794525 0.8751810193061829
+0.4738406538963318 0.0011228895746171474 0.0072822547517716885 0.16873903572559357 0.8001007437705994
+0.4568619728088379 0.0016535931499674916 0.005914532579481602 0.22099290788173676 0.550554633140564
+0.4421199858188629 0.001302712014876306 0.005215664394199848 0.17069929838180542 0.5427533984184265
+0.5354677438735962 0.0008067582384683192 0.004376892000436783 0.13063856959342957 0.9734863638877869
+0.497821569442749 0.003976357169449329 0.007926978170871735 0.15542751550674438 0.8123630285263062
+0.5388950109481812 0.0035679233260452747 0.011674313805997372 0.21819926798343658 0.8146265745162964
+0.4185383915901184 0.0008498964598402381 0.007787816226482391 0.18498583137989044 0.5862346291542053
+0.46549296379089355 0.00565000344067812 0.012204025872051716 0.2048387974500656 0.6247716546058655
+0.5033352375030518 0.002778546418994665 0.010534851811826229 0.28490906953811646 0.7399309277534485
+0.484123557806015 0.0031388390343636274 0.0055457972921431065 0.1589331328868866 0.835616946220398
+0.4628180265426636 0.014359839260578156 0.01772700808942318 0.22805503010749817 0.7076373100280762
+0.4660452604293823 0.0050162822008132935 0.013686219230294228 0.27281197905540466 0.6691707372665405
+0.5157197117805481 0.0036851675249636173 0.019362635910511017 0.2432209551334381 0.8205599188804626
+0.4696601927280426 0.012519449926912785 0.0095074363052845 0.14820043742656708 0.7674536108970642
+0.48908525705337524 0.003785428125411272 0.006559157278388739 0.25477099418640137 0.6920903325080872
+0.5540274381637573 0.0017577734543010592 0.012435010634362698 0.23806411027908325 0.7519288659095764
+0.49577245116233826 0.00410922896116972 0.010645545087754726 0.16375719010829926 0.796369194984436
+0.44187992811203003 0.0013060496421530843 0.004748817533254623 0.15295575559139252 0.7143526673316956
+0.47599563002586365 0.0008323526126332581 0.004978914279490709 0.1325908899307251 0.7562129497528076
+0.4260168671607971 0.0009374653454869986 0.004920352715998888 0.16336841881275177 0.5630993247032166
+0.47714221477508545 0.0007442696369253099 0.00592682184651494 0.21671034395694733 0.6462430357933044
+0.4844930171966553 0.0007973809842951596 0.006294551771134138 0.18713581562042236 0.7464593648910522
+0.5126148462295532 0.0009314344497397542 0.006214124150574207 0.1545909345149994 0.9186234474182129
+0.48967963457107544 0.0007373489206656814 0.00645284540951252 0.20058687031269073 0.7558043599128723
+0.48135972023010254 0.0007609523017890751 0.006001223344355822 0.1598956435918808 0.7235348224639893
+0.4945269227027893 0.0007796426652930677 0.0056257182732224464 0.2036208212375641 0.721354067325592
+0.43557417392730713 0.0007891811546869576 0.005244802683591843 0.17970915138721466 0.6233372092247009
+0.4796810746192932 0.0011300166370347142 0.00498929712921381 0.19455471634864807 0.7402910590171814
+0.46967029571533203 0.0017020611558109522 0.006869551260024309 0.21132981777191162 0.6896390318870544
+0.4444027841091156 0.0009936830028891563 0.006876486819237471 0.24862541258335114 0.6789126396179199
+0.506624162197113 0.0012237596092745662 0.005851239897310734 0.18370002508163452 0.7144790887832642
+0.47655153274536133 0.0021653976291418076 0.005901829339563847 0.14393168687820435 0.7087792158126831
+0.4652444124221802 0.0011437308276072145 0.005329997278749943 0.2257554680109024 0.6962966322898865
+0.45953065156936646 0.001011005020700395 0.004887489136308432 0.19362491369247437 0.6811630129814148
+0.4797801971435547 0.0024143229238688946 0.005747541785240173 0.2175094187259674 0.7234786748886108
+0.5044395923614502 0.0017520730616524816 0.005409899167716503 0.3185320496559143 0.7717908620834351
+0.49657195806503296 0.0007358453003689647 0.00513853644952178 0.21274538338184357 0.7416636347770691
+0.46932002902030945 0.0013843062333762646 0.004978145007044077 0.18529574573040009 0.608773410320282
+0.4972660541534424 0.00130931637249887 0.0050114733166992664 0.3302238881587982 0.7192703485488892
+0.4468075633049011 0.0007712342194281518 0.005696737673133612 0.23213262856006622 0.6976437568664551
+0.44745922088623047 0.0009711462189443409 0.006959979422390461 0.14058247208595276 0.7268856167793274
+0.45088714361190796 0.0008071225020103157 0.006309763994067907 0.19138677418231964 0.6069051623344421
+0.48489320278167725 0.0010909573175013065 0.004795796703547239 0.20218126475811005 0.5855459570884705
+0.497435986995697 0.001139750354923308 0.00538153201341629 0.222996324300766 0.7854166626930237
+0.4613885283470154 0.000881392857991159 0.0050802757032215595 0.11635623872280121 0.8525947332382202
+0.48074740171432495 0.0030587329529225826 0.008625658228993416 0.22847801446914673 0.7422481179237366
+0.5017309188842773 0.0010512865846976638 0.007826919667422771 0.3103494644165039 0.7159790396690369
+0.47559210658073425 0.002225178061053157 0.007847888395190239 0.20207253098487854 0.715442955493927
+0.45869874954223633 0.002650207607075572 0.011054056696593761 0.2297445386648178 0.6816583871841431
+0.46131837368011475 0.0011174386600032449 0.004536771681159735 0.19254422187805176 0.7102041840553284
+0.46950018405914307 0.00909801572561264 0.008909622207283974 0.22661250829696655 0.6933172941207886
+0.4214247763156891 0.006497449707239866 0.011063500307500362 0.17179220914840698 0.6427133679389954
+0.5023917555809021 0.0008661659085191786 0.005274754483252764 0.3029821217060089 0.8221216797828674
+0.492880642414093 0.0072615970857441425 0.0060391235165297985 0.18684644997119904 0.7928650379180908
+0.4788397550582886 0.007545360829681158 0.006701311562210321 0.21570608019828796 0.6431915163993835
+0.5098313093185425 0.002434398513287306 0.008320926688611507 0.3627009689807892 0.6174067854881287
+0.4589118957519531 0.0009034085087478161 0.007258423138409853 0.1601763665676117 0.7657126784324646
+0.4611167907714844 0.0028270042967051268 0.006795988883823156 0.24303993582725525 0.7650065422058105
+0.5075904130935669 0.003375949105247855 0.007823565043509007 0.13777266442775726 0.8919806480407715
+0.4702203869819641 0.0010465772356837988 0.00821524765342474 0.19802263379096985 0.6487798094749451
+0.4935121536254883 0.0020886841230094433 0.0074331266805529594 0.2631498873233795 0.5883614420890808
+0.43444395065307617 0.0037425407208502293 0.009863186627626419 0.13010160624980927 0.6109282374382019
+0.4761773347854614 0.0006031439988873899 0.005608437117189169 0.12078994512557983 0.8234469294548035
+0.4783487319946289 0.006665131077170372 0.009637895971536636 0.1971752941608429 0.8549257516860962
+0.45105767250061035 0.007580564357340336 0.016658006235957146 0.21462593972682953 0.6012320518493652
+0.4734381437301636 0.0007374439155682921 0.006639409344643354 0.21869687736034393 0.5862569212913513
+0.4246778190135956 0.009040635079145432 0.014419491402804852 0.13083882629871368 0.5642983913421631
+0.4521505534648895 0.0035676041152328253 0.008451257832348347 0.13652069866657257 0.7713356018066406
+0.46347418427467346 0.004474471788853407 0.012104629538953304 0.25739502906799316 0.7135206460952759
+0.4616909921169281 0.011721584014594555 0.025605197995901108 0.2508440911769867 0.654846727848053
+0.4577261805534363 0.0011829850263893604 0.00920102745294571 0.307209849357605 0.564282238483429
+0.46868836879730225 0.00884971208870411 0.010750981979072094 0.1627051830291748 0.748320996761322
+0.5073087215423584 0.006765571888536215 0.00895005278289318 0.1301511973142624 0.9906520247459412
+0.46549052000045776 0.001688139047473669 0.013272235170006752 0.31725040078163147 0.6856728196144104
+0.4824959635734558 0.005733713507652283 0.020579872652888298 0.31106051802635193 0.5915639400482178
+0.4468582272529602 0.0006918239523656666 0.0051561142317950726 0.1725844442844391 0.5185928344726562
+0.45560476183891296 0.0036388766020536423 0.01580369658768177 0.1718340963125229 0.7252013087272644
+0.5039176344871521 0.0006632275762967765 0.00647424953058362 0.12166933715343475 0.9065179824829102
+0.4416295289993286 0.009858477860689163 0.013808031566441059 0.21992850303649902 0.6979712843894958
+0.46463796496391296 0.008464218117296696 0.024323586374521255 0.3152313232421875 0.5894766449928284
+0.4605978727340698 0.002759055932983756 0.005415565334260464 0.2157171666622162 0.6088176369667053
+0.463150292634964 0.018047930672764778 0.020319603383541107 0.1346299648284912 0.7771103978157043
+0.45447903871536255 0.003377012675628066 0.0066894302144646645 0.169127956032753 0.7769805788993835
+0.47139111161231995 0.009494817815721035 0.019639529287815094 0.1754053682088852 0.6981686949729919
+0.4595034122467041 0.015959206968545914 0.033623505383729935 0.23693209886550903 0.622944712638855
+0.439980685710907 0.000796811596956104 0.007147158496081829 0.22038672864437103 0.594070553779602
+0.4086647033691406 0.013295812532305717 0.014341839589178562 0.1130317971110344 0.6115342378616333
+0.46499431133270264 0.009335474111139774 0.011730819940567017 0.1843651682138443 0.8038354516029358
+0.40187394618988037 0.0016200188547372818 0.008556694723665714 0.1768599897623062 0.6647997498512268
+0.4428942799568176 0.010642306879162788 0.020922277122735977 0.28426411747932434 0.5512857437133789
+0.4813539981842041 0.002748518716543913 0.006223410367965698 0.3022688031196594 0.6577574014663696
+0.4697670042514801 0.0023691540118306875 0.01255676057189703 0.12256406992673874 0.7824207544326782
+0.45119136571884155 0.0023792057763785124 0.012976507656276226 0.15429896116256714 0.7468892931938171
+0.46751630306243896 0.0021494878455996513 0.006225673481822014 0.20959241688251495 0.666170597076416
+0.4394652545452118 0.007595413364470005 0.016127431765198708 0.4175778031349182 0.5206870436668396
+0.4741488993167877 0.0018842729041352868 0.00868831668049097 0.19868186116218567 0.8104978799819946
+0.44355297088623047 0.003069357480853796 0.005453940946608782 0.12463749200105667 0.6982443928718567
+0.4092077910900116 0.005856199190020561 0.005513058044016361 0.14675141870975494 0.6522248983383179
+0.4116949439048767 0.001586084719747305 0.005456491839140654 0.19625972211360931 0.6112507581710815
+0.4246901273727417 0.001945348223671317 0.012420210987329483 0.26286187767982483 0.5439504981040955
+0.4669681489467621 0.0017788823461160064 0.01017887145280838 0.24663951992988586 0.6487995386123657
+0.4296599328517914 0.0010416620643809438 0.004505017306655645 0.2379056066274643 0.6480744481086731
+0.47956153750419617 0.002185665536671877 0.00464646378532052 0.11823280155658722 0.8640051484107971
+0.47403156757354736 0.0012195464223623276 0.004492436069995165 0.22700028121471405 0.7204605340957642
+0.4760505259037018 0.0008208141662180424 0.005964568350464106 0.26291969418525696 0.6049157381057739
+0.45194289088249207 0.001015354529954493 0.004678224213421345 0.2051498144865036 0.6531049013137817
+0.47698765993118286 0.0007298646378330886 0.005111690144985914 0.14665307104587555 0.7326868176460266
+0.39795321226119995 0.0011944308644160628 0.005126813892275095 0.21110203862190247 0.627784013748169
+0.44177767634391785 0.0030030841007828712 0.0037029297091066837 0.18536704778671265 0.7182231545448303
+0.46627646684646606 0.0034957430325448513 0.008637897670269012 0.20918704569339752 0.6604366898536682
+0.4506126940250397 0.0009213724406436086 0.007615764159709215 0.17032895982265472 0.5709268450737
+0.4865127205848694 0.005029088817536831 0.005463093053549528 0.21667630970478058 0.7159169912338257
+0.41900137066841125 0.007119013927876949 0.0051273140124976635 0.15319719910621643 0.6857679486274719
+0.43448957800865173 0.0013299244455993176 0.006398639641702175 0.15712428092956543 0.7667754292488098
+0.4293963313102722 0.0019191950559616089 0.011730639263987541 0.17930781841278076 0.5514909029006958
+0.4785860776901245 0.0017598429694771767 0.007315556053072214 0.269386351108551 0.6666773557662964
+0.44185322523117065 0.000625384971499443 0.004015292506664991 0.14019900560379028 0.6853488087654114
+0.4457603096961975 0.0008071414777077734 0.004654363263398409 0.1224256232380867 0.7388294339179993
+0.42833957076072693 0.0008585074101574719 0.005016150884330273 0.1485302448272705 0.6573949456214905
+0.45357051491737366 0.0016756892437115312 0.007687451783567667 0.2557287812232971 0.5850702524185181
+0.42253538966178894 0.0007295117247849703 0.005851286929100752 0.18762625753879547 0.6689760684967041
+0.4763101041316986 0.001222863793373108 0.005834463983774185 0.23055632412433624 0.8114944100379944
+0.43298083543777466 0.0009798851097002625 0.005966744385659695 0.17528307437896729 0.7099570631980896
+0.45800718665122986 0.000834575854241848 0.006560661364346743 0.31839612126350403 0.5303618311882019
+0.44551536440849304 0.001087404671125114 0.006020484026521444 0.15351825952529907 0.629335880279541
+0.47260287404060364 0.0009246015106327832 0.005613191518932581 0.14060448110103607 0.8271851539611816
+0.463400661945343 0.0006411272333934903 0.005852679256349802 0.16605715453624725 0.8388072848320007
+0.40312299132347107 0.0006952009280212224 0.007373136002570391 0.1767524927854538 0.5922203660011292
+0.427929162979126 0.0011683093616738915 0.006507310084998608 0.26636552810668945 0.560436487197876
+0.456468403339386 0.0017957231029868126 0.006844324059784412 0.21412138640880585 0.6151395440101624
+0.44624677300453186 0.0005312350112944841 0.0058635538443923 0.17551633715629578 0.666233241558075
+0.5165451169013977 0.0035546254366636276 0.0038193752989172935 0.13803508877754211 0.9592968225479126
+0.43819886445999146 0.008978630416095257 0.011151399463415146 0.19663551449775696 0.6718003749847412
+0.441683292388916 0.0023578894324600697 0.010087264701724052 0.16431309282779694 0.5080702900886536
+0.4103125035762787 0.004257073625922203 0.00585534842684865 0.22510288655757904 0.5159546732902527
+0.45108896493911743 0.011575561948120594 0.01008404791355133 0.14927469193935394 0.7587457299232483
+0.39419659972190857 0.001976407365873456 0.005013235379010439 0.1339208334684372 0.7046734094619751
+0.3889654278755188 0.004333930090069771 0.015645796433091164 0.2693072259426117 0.5439144372940063
+0.4638195335865021 0.004969990812242031 0.013801210559904575 0.24545440077781677 0.6068112850189209
+0.4398439824581146 0.0006494777626357973 0.00437796488404274 0.25704294443130493 0.5135030746459961
+0.44349756836891174 0.003348607337102294 0.009271930903196335 0.1991293579339981 0.6753153800964355
+0.4623958468437195 0.0008008646545931697 0.0043210540898144245 0.11543582379817963 0.871854841709137
+0.43834635615348816 0.004276788327842951 0.01719503104686737 0.18353034555912018 0.7122299075126648
+0.41507208347320557 0.0031102432403713465 0.02046443335711956 0.2690010964870453 0.45982274413108826
+0.40691646933555603 0.0028005915228277445 0.006736878305673599 0.18760909140110016 0.5985591411590576
+0.4457789361476898 0.009150130674242973 0.017294909805059433 0.20076331496238708 0.6411207318305969
+0.47771236300468445 0.0009204268571920693 0.006731647532433271 0.12763075530529022 0.8080651760101318
+0.45107170939445496 0.010702481493353844 0.010712976567447186 0.18225443363189697 0.6338199973106384
+0.4484401345252991 0.015598793514072895 0.02112863026559353 0.25716060400009155 0.6353996396064758
+0.4030577838420868 0.0011987262405455112 0.0071945795789361 0.16408827900886536 0.6034305691719055
+0.41858989000320435 0.009536021389067173 0.010946916416287422 0.25884944200515747 0.6644971966743469
+0.4177763760089874 0.011516215279698372 0.012270975857973099 0.16799324750900269 0.6584775447845459
+0.43498602509498596 0.000840242428239435 0.006432653870433569 0.1539316177368164 0.6086443066596985
+0.4058675765991211 0.005546347703784704 0.012080268003046513 0.16984735429286957 0.5979027152061462
+0.4682970643043518 0.005655169021338224 0.007954542525112629 0.22385002672672272 0.7136090397834778
+0.4298030138015747 0.0008242715266533196 0.004054277669638395 0.12216580659151077 0.7027520537376404
+0.4490552842617035 0.0009924692567437887 0.0056541976518929005 0.11264687776565552 0.7799272537231445
+0.4026467800140381 0.0007482774672098458 0.006160496734082699 0.2138465940952301 0.589530348777771
+0.43208134174346924 0.0007733392994850874 0.010076218284666538 0.25919002294540405 0.591018557548523
+0.42688828706741333 0.0010257669491693377 0.007177908904850483 0.22019825875759125 0.5915098190307617
+0.388797789812088 0.0020975982770323753 0.004441822879016399 0.09254787862300873 0.6103758215904236
+0.40640759468078613 0.0010858413297683 0.003996751271188259 0.1059659868478775 0.6604105830192566
+0.45013606548309326 0.00141547538805753 0.005270882043987513 0.12104158848524094 0.7663363814353943
+0.4137699007987976 0.0035665961913764477 0.010932452976703644 0.17383895814418793 0.6124061942100525
+0.4257180690765381 0.0009453560342080891 0.007223447319120169 0.1649145781993866 0.5807569622993469
+0.41319626569747925 0.0026089351158589125 0.005880445707589388 0.19944235682487488 0.5911976099014282
+0.4435402452945709 0.003985760733485222 0.00935433991253376 0.1502569615840912 0.731410801410675
+0.43358922004699707 0.0006671463488601148 0.005141690839082003 0.13919496536254883 0.6521843671798706
+0.4165436625480652 0.004310247488319874 0.0068842270411551 0.1901981383562088 0.5872876644134521
+0.42844805121421814 0.004744178149849176 0.006906678434461355 0.23309184610843658 0.5962748527526855
+0.3669070601463318 0.0008400012156926095 0.0037643490359187126 0.12734617292881012 0.6029011011123657
+0.3793485164642334 0.001471943804062903 0.004572445061057806 0.14696359634399414 0.5928006768226624
+0.43195033073425293 0.0015629176050424576 0.004284224472939968 0.20073792338371277 0.6163463592529297
+0.4172697365283966 0.0006600331398658454 0.006146550644189119 0.13008835911750793 0.5951653718948364
+0.42623794078826904 0.0007871647831052542 0.008821554481983185 0.18411441147327423 0.6217836141586304
+0.4316340684890747 0.0008443306433036923 0.005098294001072645 0.17570027709007263 0.6747812628746033
+0.4114747643470764 0.001114137121476233 0.004342076368629932 0.1952752023935318 0.7113766670227051
+0.4006328284740448 0.0008198711439035833 0.005546742118895054 0.20701885223388672 0.5850439667701721
+0.4415491223335266 0.0007533960742875934 0.005646026227623224 0.27345141768455505 0.6373819708824158
+0.45469099283218384 0.0007903514197096229 0.004936445504426956 0.22333917021751404 0.717083215713501
+0.3968217968940735 0.0006448497879318893 0.0048552220687270164 0.17349858582019806 0.5931052565574646
+0.4055040180683136 0.0006582874339073896 0.004333429504185915 0.26866355538368225 0.5996370315551758
+0.43051600456237793 0.0008966190507635474 0.0043706344440579414 0.2973853349685669 0.5755952000617981
+0.40549373626708984 0.0010916703613474965 0.004045742563903332 0.13392144441604614 0.6209725141525269
+0.4258652925491333 0.0011476228246465325 0.006133385468274355 0.14109882712364197 0.7572489380836487
+0.4371739327907562 0.0007846865337342024 0.007833288982510567 0.1727711260318756 0.6750352382659912
+0.4255538880825043 0.0017170332139357924 0.006192405708134174 0.1858053356409073 0.5979868173599243
+0.41886794567108154 0.003492131596431136 0.006118918769061565 0.13266974687576294 0.5819666981697083
+0.47198551893234253 0.0014582113362848759 0.005023791920393705 0.0810389518737793 0.839065670967102
+0.4337400197982788 0.0020155631937086582 0.004280561115592718 0.166450634598732 0.6599277257919312
+0.3979232907295227 0.005608577746897936 0.006313923746347427 0.22417771816253662 0.5523231029510498
+0.4120914041996002 0.002529792720451951 0.005648199934512377 0.19727560877799988 0.6572379469871521
+0.40567249059677124 0.0007271890062838793 0.004262316972017288 0.24064524471759796 0.6470578908920288
+0.4419449269771576 0.0016615785425528884 0.0047027599066495895 0.16939567029476166 0.593382716178894
+0.4140852093696594 0.0012824650621041656 0.005962037481367588 0.11607944965362549 0.4971284568309784
+0.4233933091163635 0.0007872665883041918 0.007345297839492559 0.30761685967445374 0.6315985918045044
+0.3857998847961426 0.0008831270970404148 0.006352861411869526 0.24404112994670868 0.6501202583312988
+0.3976293206214905 0.0007485680980607867 0.006710420828312635 0.2008805274963379 0.6393155455589294
+0.454532653093338 0.000948832486756146 0.005026803817600012 0.1644153594970703 0.6970148086547852
+0.44037559628486633 0.0009918948635458946 0.0044892532750964165 0.2267671525478363 0.5003777146339417
+0.4170044958591461 0.000685044564306736 0.004340864717960358 0.14036597311496735 0.6669979095458984
+0.43260759115219116 0.0009968999074772 0.004501971881836653 0.09660423547029495 0.805479109287262
+0.430500328540802 0.002036050660535693 0.007425840478390455 0.1287892907857895 0.7402792572975159
+0.4115169644355774 0.0008245578501373529 0.0074132392182946205 0.20753420889377594 0.5692099928855896
+0.43734607100486755 0.0017305142246186733 0.0071724713779985905 0.2082086205482483 0.556268572807312
+0.46214497089385986 0.0026953492779284716 0.011912129819393158 0.1670667827129364 0.706014096736908
+0.41268759965896606 0.0007361852913163602 0.005349643528461456 0.14551106095314026 0.7249552011489868
+0.4097630977630615 0.007823568768799305 0.006995814386755228 0.15297158062458038 0.6728386282920837
+0.43741703033447266 0.008197790011763573 0.012462352402508259 0.2307959347963333 0.6891575455665588
+0.3990521728992462 0.020542368292808533 0.006469571962952614 0.21686430275440216 0.49835410714149475
+0.43690216541290283 0.007474059239029884 0.006835822947323322 0.19705574214458466 0.5625074505805969
+0.39823514223098755 0.0075992620550096035 0.005938537884503603 0.1750141978263855 0.5609212517738342
+0.45283353328704834 0.0009545068605802953 0.007524420507252216 0.18641318380832672 0.812333881855011
+0.44840681552886963 0.005938120186328888 0.026302995160222054 0.2740974426269531 0.6540762782096863
+0.4496174454689026 0.0022286921739578247 0.016801662743091583 0.16366912424564362 0.5745804309844971
+0.42321455478668213 0.003589129075407982 0.004869562108069658 0.14923007786273956 0.6726296544075012
+0.43105757236480713 0.007710175588726997 0.011721477843821049 0.161519393324852 0.8018013834953308
+0.4161107540130615 0.0006389457848854363 0.004676212556660175 0.2805297374725342 0.7207099199295044
+0.4160456657409668 0.006003416143357754 0.009402153082191944 0.3638429343700409 0.4591243267059326
+0.4155438542366028 0.005384758580476046 0.004661516286432743 0.18202976882457733 0.46242862939834595
+0.4532266855239868 0.0006809373735450208 0.007343207485973835 0.14060775935649872 0.737722635269165
+0.48194462060928345 0.0005540085257962346 0.00767787080258131 0.19462329149246216 0.8297021389007568
+0.401314914226532 0.001603217446245253 0.005690445192158222 0.1573467254638672 0.6749037504196167
+0.42368265986442566 0.0028484484646469355 0.01926959678530693 0.22029876708984375 0.5167274475097656
+0.40014180541038513 0.0013046704698354006 0.009601823054254055 0.24779485166072845 0.5539101362228394
+0.3890220522880554 0.009368169121444225 0.006727817002683878 0.12169992923736572 0.5971701741218567
+0.41393953561782837 0.005664631258696318 0.005820560269057751 0.09608376771211624 0.7287933230400085
+0.42030835151672363 0.0015283440006896853 0.006023585330694914 0.14155547320842743 0.6568499207496643
+0.385957270860672 0.009093555621802807 0.018425343558192253 0.14048993587493896 0.5290923118591309
+0.38326436281204224 0.0035160393454134464 0.010771429166197777 0.11328558623790741 0.5339773297309875
+0.3702490031719208 0.0020964713767170906 0.004871277138590813 0.15807686746120453 0.5242477059364319
+0.3771166205406189 0.007164258975535631 0.009967396967113018 0.1479993313550949 0.6129216551780701
+0.43015986680984497 0.0016986937262117863 0.006074936594814062 0.12019592523574829 0.7149466872215271
+0.4022238254547119 0.0025867794174700975 0.009356009773910046 0.26479989290237427 0.5734118819236755
+0.3818894624710083 0.003727919887751341 0.007758487015962601 0.23326337337493896 0.49358612298965454
+0.3934522271156311 0.0007310439832508564 0.004645303823053837 0.12853381037712097 0.5498183965682983
+0.3840161859989166 0.0007641376578249037 0.008144578896462917 0.14985083043575287 0.605265200138092
+0.3979424834251404 0.0008080817642621696 0.003680937457829714 0.1382560133934021 0.6986003518104553
+0.40222349762916565 0.0032537623774260283 0.010242726653814316 0.16705572605133057 0.5712864398956299
+0.4138050675392151 0.001345892553217709 0.01359080895781517 0.21935316920280457 0.5586656928062439
+0.4312012195587158 0.0029791409615427256 0.0061542498879134655 0.16772015392780304 0.5814857482910156
+0.44830048084259033 0.008316039107739925 0.007568941451609135 0.10962574928998947 0.7893681526184082
+0.3919224143028259 0.0016671852208673954 0.004681074060499668 0.09773905575275421 0.6634390950202942
+0.41964131593704224 0.0034512251149863005 0.012081853114068508 0.18048246204853058 0.5934293270111084
+0.42315375804901123 0.005343726836144924 0.012892314232885838 0.2755239009857178 0.5607163906097412
+0.40675580501556396 0.0009359159739688039 0.0039809029549360275 0.1213616281747818 0.5355460047721863
+0.4155401587486267 0.0029269487131386995 0.00810465682297945 0.09157856553792953 0.6910236477851868
+0.43216434121131897 0.001571563770994544 0.005213200580328703 0.10109487175941467 0.7200274467468262
+0.40162110328674316 0.0015224877279251814 0.0064756241627037525 0.14881719648838043 0.6154745817184448
+0.4246481657028198 0.003074715379625559 0.010719112120568752 0.21296818554401398 0.5809308886528015
+0.38607361912727356 0.0007054316811263561 0.005257423967123032 0.13271397352218628 0.6357612013816833
+0.4029538929462433 0.0027740243822336197 0.004722914192825556 0.11864388734102249 0.6423448920249939
+0.3778095841407776 0.0020988122560083866 0.004296186845749617 0.11509207636117935 0.5891531109809875
+0.3882005214691162 0.0007971262675710022 0.005680525675415993 0.24053847789764404 0.6134719848632812
+0.437467485666275 0.002824644558131695 0.010003827512264252 0.21875226497650146 0.6155802011489868
+0.37269654870033264 0.0009918755386024714 0.0057030632160604 0.21971753239631653 0.5140681266784668
+0.40362387895584106 0.001441026572138071 0.004902145825326443 0.16911037266254425 0.597772479057312
+0.3680979311466217 0.0016552333254367113 0.005510261747986078 0.1483861654996872 0.4731406271457672
+0.44221043586730957 0.000663818558678031 0.004247145261615515 0.2533680498600006 0.7076061964035034
+0.40299153327941895 0.002514071064069867 0.010339920409023762 0.15876498818397522 0.6188191771507263
+0.42301130294799805 0.0015653459122404456 0.008238459005951881 0.21449537575244904 0.5973901748657227
+0.3913719058036804 0.0012723127147182822 0.004862242378294468 0.17510674893856049 0.5558248162269592
+0.39601534605026245 0.0033299578353762627 0.007151035126298666 0.2510072886943817 0.5618468523025513
+0.38880419731140137 0.0012348340824246407 0.0046434104442596436 0.21815034747123718 0.5956132411956787
+0.4012744724750519 0.002525925636291504 0.006315841339528561 0.19068314135074615 0.6297574043273926
+0.4101829528808594 0.004137184005230665 0.009035997092723846 0.2136443853378296 0.5045102834701538
+0.4185577630996704 0.0009643565863370895 0.003968564327806234 0.11142775416374207 0.6740867495536804
+0.3898114562034607 0.001772760646417737 0.005463216919451952 0.08358678221702576 0.6182454228401184
+0.4186103940010071 0.0018598339520394802 0.004770010709762573 0.09266608953475952 0.6267331838607788
+0.3516678512096405 0.0006829995545558631 0.005611513275653124 0.19044235348701477 0.5170385837554932
+0.3758239150047302 0.001224306644871831 0.007374018896371126 0.16144216060638428 0.569040060043335
+0.388139545917511 0.0007125858683139086 0.004801565315574408 0.22543060779571533 0.5738520622253418
+0.4319912791252136 0.0007784130284562707 0.004301471635699272 0.13861706852912903 0.6679185032844543
+0.4268451929092407 0.0007164721610024571 0.003759245853871107 0.21676115691661835 0.6019655466079712
+0.38032859563827515 0.001129714772105217 0.004362639971077442 0.18181268870830536 0.5450412631034851
+0.373304545879364 0.0016516041941940784 0.00516157504171133 0.1191260814666748 0.5378116965293884
+0.4150649309158325 0.0008086542366072536 0.005434529390186071 0.12287945300340652 0.6910677552223206
+0.35286810994148254 0.0011541404528543353 0.005135270766913891 0.16113419830799103 0.5853803157806396
+0.3519135117530823 0.002007072325795889 0.0047774361446499825 0.2207425981760025 0.4351491332054138
+0.3664098381996155 0.001591973239555955 0.004529158119112253 0.20993788540363312 0.5344803929328918
+0.39465659856796265 0.0007078805356286466 0.004090670496225357 0.20781253278255463 0.5649989247322083
+0.3675645589828491 0.0017166421748697758 0.005264948587864637 0.10982542484998703 0.5895769000053406
+0.37158265709877014 0.0018130058888345957 0.006255144253373146 0.1363065540790558 0.5225875973701477
+0.40730276703834534 0.0007228117901831865 0.004857086110860109 0.19555920362472534 0.5367923974990845
+0.40761515498161316 0.001227556960657239 0.004153005313128233 0.12621809542179108 0.6544003486633301
+0.390079140663147 0.0012680647196248174 0.004935831297188997 0.192135751247406 0.5913920998573303
+0.3679525852203369 0.000740297429729253 0.006388411391526461 0.17636625468730927 0.5705025792121887
+0.3826093077659607 0.0006502392934635282 0.005622646305710077 0.1857645958662033 0.5480114817619324
+0.40074270963668823 0.000642158673144877 0.005481115076690912 0.22745779156684875 0.4708496630191803
+0.38186076283454895 0.0005640259478241205 0.005974270403385162 0.11751141399145126 0.5828209519386292
+0.38840270042419434 0.001551528344862163 0.004002160858362913 0.12875674664974213 0.6971627473831177
+0.3983726501464844 0.004718085750937462 0.005081729032099247 0.12927711009979248 0.6612480878829956
+0.42433953285217285 0.0024621060583740473 0.005975106731057167 0.1667223423719406 0.5639594793319702
+0.40494704246520996 0.001122754067182541 0.004529905039817095 0.22910474240779877 0.49340835213661194
+0.44293147325515747 0.004773673135787249 0.007810778915882111 0.12861166894435883 0.6851856112480164
+0.3503359258174896 0.0013454596046358347 0.004304550122469664 0.11984408646821976 0.5864096879959106
+0.41596704721450806 0.002714023692533374 0.008973727002739906 0.13800475001335144 0.7620924115180969
+0.39145249128341675 0.004481246694922447 0.018211588263511658 0.25486403703689575 0.45792585611343384
+0.40962374210357666 0.0008603386813774705 0.005847742315381765 0.25674521923065186 0.523051917552948
+0.38604462146759033 0.006788067054003477 0.03117668814957142 0.23800204694271088 0.565774142742157
+0.4280847907066345 0.0044072275049984455 0.005876705050468445 0.21456606686115265 0.7148839235305786
+0.39716625213623047 0.0014700823230668902 0.006148984655737877 0.17756010591983795 0.6029137969017029
+0.3714541792869568 0.007306160405278206 0.01660137064754963 0.19036559760570526 0.457062304019928
+0.38722729682922363 0.0019518957706168294 0.007218681275844574 0.10926567763090134 0.6396759748458862
+0.4251393675804138 0.0023918307852 0.004386812448501587 0.1673296093940735 0.7013654112815857
+0.39331239461898804 0.0045392573811113834 0.006350478157401085 0.130865678191185 0.6535371541976929
+0.3674159646034241 0.0008079889812506735 0.004268228076398373 0.15644586086273193 0.5165382027626038
+0.38683927059173584 0.001784416032023728 0.007171287201344967 0.19297586381435394 0.5123934149742126
+0.36206111311912537 0.0018497529672458768 0.00555762555450201 0.1717166006565094 0.5187073945999146
+0.4187358319759369 0.0006170244887471199 0.00458109937608242 0.14223608374595642 0.7062562108039856
+0.36546769738197327 0.0005645801429636776 0.0050693051889538765 0.16929563879966736 0.5676286816596985
+0.39991065859794617 0.0007227901369333267 0.004578022286295891 0.17825128138065338 0.5900668501853943
+0.4193187654018402 0.001074114697985351 0.005826740991324186 0.16642598807811737 0.5259043574333191
+0.36568450927734375 0.0005937638343311846 0.004633817821741104 0.2410266399383545 0.4398597478866577
+0.37844783067703247 0.0011115173110738397 0.0041538430377841 0.0906984806060791 0.7164502739906311
+0.36605364084243774 0.0009539870661683381 0.004075151868164539 0.07661056518554688 0.6930822730064392
+0.42539408802986145 0.00073765660636127 0.007669060491025448 0.17437759041786194 0.6190871000289917
+0.41400009393692017 0.0006839885609224439 0.0060769119299948215 0.1837179809808731 0.4309770166873932
+0.342412531375885 0.0008654256816953421 0.0038964960258454084 0.10380592942237854 0.5078848004341125
+0.40658432245254517 0.0006845698226243258 0.004970873706042767 0.12431128323078156 0.7468810677528381
+0.3931249976158142 0.0017716183792799711 0.0037612866144627333 0.10229279100894928 0.6641336679458618
+0.37311291694641113 0.004304242320358753 0.008607404306530952 0.19810958206653595 0.4244391620159149
+0.39737141132354736 0.0010489559499546885 0.005142331589013338 0.11801906675100327 0.5564013719558716
+0.4107265770435333 0.002049149014055729 0.0045367032289505005 0.2091553658246994 0.622759997844696
+0.3936305046081543 0.0031923882197588682 0.004535534419119358 0.09984289109706879 0.6054033637046814
+0.383392333984375 0.0007066613179631531 0.004590777214616537 0.14872097969055176 0.5762981176376343
+0.3409060835838318 0.0015074280090630054 0.008216965943574905 0.19434888660907745 0.4793667197227478
+0.35662615299224854 0.0012527832295745611 0.006866372190415859 0.15149667859077454 0.5180526375770569
+0.40076619386672974 0.0006242989329621196 0.004120826255530119 0.2533412575721741 0.5125054121017456
+0.35582080483436584 0.001173898228444159 0.0038015744648873806 0.15571443736553192 0.5603766441345215
+0.3752354681491852 0.0006752857589162886 0.004727233666926622 0.1128702163696289 0.6268022656440735
+0.3885999023914337 0.0008052477496676147 0.008208763785660267 0.159126415848732 0.6439514756202698
+0.35980886220932007 0.0007459894986823201 0.007304786238819361 0.2123909890651703 0.5028539896011353
+0.32897716760635376 0.0014748780522495508 0.006897145416587591 0.15959066152572632 0.3855094909667969
+0.34930214285850525 0.001664655515924096 0.00920945219695568 0.10494860261678696 0.6318769454956055
+0.3938162922859192 0.0009334048954769969 0.004993568640202284 0.11416914314031601 0.7024722695350647
+0.35799187421798706 0.006782091688364744 0.005582285113632679 0.25747451186180115 0.41220366954803467
+0.38116416335105896 0.00411059008911252 0.004455855116248131 0.29764071106910706 0.4798765182495117
+0.33042609691619873 0.0006183412042446434 0.005085149314254522 0.15972305834293365 0.47864821553230286
+0.3833528757095337 0.0028164663817733526 0.007762008812278509 0.15979069471359253 0.685022234916687
+0.38667815923690796 0.0008370296563953161 0.004876331426203251 0.20236089825630188 0.696013867855072
+0.340024471282959 0.0012423768639564514 0.008733261376619339 0.1832210123538971 0.45721912384033203
+0.4247669279575348 0.0009324613492935896 0.0211581289768219 0.2371588796377182 0.5540253520011902
+0.3803781569004059 0.0009851412614807487 0.0028298122342675924 0.12272180616855621 0.5223850011825562
+0.4275299310684204 0.0013639092212542892 0.0035730996169149876 0.14415767788887024 0.735467255115509
+0.3931725025177002 0.0008431164897046983 0.003947127610445023 0.12982970476150513 0.6912986040115356
+0.39491063356399536 0.00318449386395514 0.015467989258468151 0.17345722019672394 0.6199347972869873
+0.35346490144729614 0.0008490128675475717 0.011079245246946812 0.2623752951622009 0.35006481409072876
+0.34526437520980835 0.004762882832437754 0.004406699910759926 0.13299939036369324 0.4002949893474579
+0.4358685612678528 0.006535826250910759 0.008545281365513802 0.10210206359624863 0.8381836414337158
+0.4341641366481781 0.0008387371781282127 0.003037473186850548 0.2115255892276764 0.7144841551780701
+0.3984065651893616 0.010068221017718315 0.01896457001566887 0.27866825461387634 0.503275454044342
+0.38993680477142334 0.005036020185798407 0.013225597329437733 0.1885240375995636 0.4850143790245056
+0.3892534077167511 0.0021222273353487253 0.005432853475213051 0.10309706628322601 0.6622992753982544
+0.4090871512889862 0.008610673248767853 0.01338523905724287 0.13471360504627228 0.6576066613197327
+0.39982545375823975 0.0019060062477365136 0.006206015590578318 0.10532111674547195 0.6633492708206177
+0.39289727807044983 0.0034058857709169388 0.0066467393189668655 0.22335052490234375 0.4562149941921234
+0.39105144143104553 0.006588323973119259 0.005898552946746349 0.1851281374692917 0.45149147510528564
+0.3718295097351074 0.0016010792460292578 0.0029280511662364006 0.1677841991186142 0.5605844855308533
+0.42676225304603577 0.0006180388736538589 0.003981242887675762 0.06783489137887955 0.7574653625488281
+0.3602580726146698 0.0007189951138570905 0.004756419453769922 0.1526767760515213 0.6210017800331116
+0.4050089716911316 0.0010904466034844518 0.017150985077023506 0.16556887328624725 0.5098960995674133
+0.35816970467567444 0.002392412396147847 0.009322665631771088 0.2649056315422058 0.4187717139720917
+0.37841689586639404 0.008122241124510765 0.005757785867899656 0.1179070770740509 0.5190683007240295
+0.3355265259742737 0.0037819910794496536 0.007425970397889614 0.07535537332296371 0.5256601572036743
+0.37489622831344604 0.0026483219116926193 0.0024731054436415434 0.09114892780780792 0.6317735910415649
+0.3414759635925293 0.015245789662003517 0.013478509150445461 0.14757685363292694 0.5125396847724915
+0.3417526185512543 0.0080110440030694 0.010766702704131603 0.2637268006801605 0.4921133816242218
+0.3698473572731018 0.0014279646566137671 0.00462626526132226 0.09948074817657471 0.5420567989349365
+0.3448182940483093 0.012173028662800789 0.015421617776155472 0.09211202710866928 0.4615303575992584
+0.3660128712654114 0.005700836423784494 0.01008895318955183 0.16629250347614288 0.4961576461791992
+0.3481772243976593 0.0020151108037680387 0.0038800789043307304 0.09784556180238724 0.6032637357711792
+0.3949940800666809 0.012493547983467579 0.014170869253575802 0.18838240206241608 0.6284649968147278
+0.39396077394485474 0.006614100653678179 0.008577016182243824 0.1814158707857132 0.5804779529571533
+0.3623548150062561 0.0010524049866944551 0.0034060284961014986 0.14787614345550537 0.505547046661377
+0.39456266164779663 0.0071918475441634655 0.008841495960950851 0.18287760019302368 0.5320199131965637
+0.3567538857460022 0.0034959623590111732 0.004382925108075142 0.14916759729385376 0.46016666293144226
+0.40306591987609863 0.0009864266030490398 0.009205608628690243 0.11162156611680984 0.6488158702850342
+0.321266770362854 0.0035514237824827433 0.014340987429022789 0.1988515555858612 0.47277241945266724
+0.36204081773757935 0.0009433329687453806 0.00526107894256711 0.21243207156658173 0.48520198464393616
+0.40692803263664246 0.0017970945918932557 0.004130037035793066 0.1702946424484253 0.5824297070503235
+0.3493887484073639 0.0022321571595966816 0.004683372564613819 0.22221025824546814 0.4608308672904968
+0.36284857988357544 0.000580382882617414 0.002914133947342634 0.14641736447811127 0.570178747177124
+0.38101649284362793 0.002250281861051917 0.005539076402783394 0.22790482640266418 0.5522742867469788
+0.3577790856361389 0.0018835043301805854 0.0051451558247208595 0.15102793276309967 0.5642304420471191
+0.33661720156669617 0.0005916023510508239 0.0031521052587777376 0.13529624044895172 0.5210917592048645
+0.3159909248352051 0.0017179258866235614 0.006229298654943705 0.18707694113254547 0.4399082362651825
+0.3019115924835205 0.0006936615682207048 0.003983885049819946 0.14030711352825165 0.3685658574104309
+0.36095553636550903 0.0020669407676905394 0.0027300615329295397 0.12222426384687424 0.593160092830658
+0.3866303563117981 0.00491772685199976 0.009698599576950073 0.140630304813385 0.6628419160842896
+0.368391752243042 0.0014856873312965035 0.010318013839423656 0.21667848527431488 0.5041463375091553
+0.42632728815078735 0.0035402278881520033 0.004293353762477636 0.24365992844104767 0.5233582854270935
+0.3523927927017212 0.009816458448767662 0.007751386146992445 0.21690478920936584 0.4416887164115906
+0.34831365942955017 0.003200789913535118 0.0043126908130943775 0.1258208155632019 0.5394243597984314
+0.3591720759868622 0.0024994288105517626 0.007400800008326769 0.23743705451488495 0.6220476031303406
+0.3757239282131195 0.00826522521674633 0.017011262476444244 0.22731667757034302 0.5986714959144592
+0.38265740871429443 0.0021637813188135624 0.00616917759180069 0.17853739857673645 0.4565321207046509
+0.38778555393218994 0.002907153684645891 0.008630706928670406 0.20295362174510956 0.3819776177406311
+0.41137272119522095 0.005537599325180054 0.016750991344451904 0.09287281334400177 0.6881041526794434
+0.41258352994918823 0.0005325995734892786 0.0035617284011095762 0.11268889904022217 0.7965841889381409
+0.4073485732078552 0.009026650339365005 0.014909941703081131 0.16292664408683777 0.7492313385009766
+0.38704535365104675 0.006389431655406952 0.016514195129275322 0.21046626567840576 0.5174365043640137
+0.36939746141433716 0.0013747639022767544 0.005538046360015869 0.2494988888502121 0.42136579751968384
+0.3824332058429718 0.010010868310928345 0.02744889445602894 0.12760184705257416 0.5440976619720459
+0.402204692363739 0.0015451255021616817 0.010781431570649147 0.16491664946079254 0.6915866732597351
+0.3549898862838745 0.010504421778023243 0.01071908324956894 0.16269639134407043 0.5758218765258789
+0.38470110297203064 0.020215900614857674 0.03582127019762993 0.24333056807518005 0.5158803462982178
+0.3277282416820526 0.001940088695846498 0.010729463770985603 0.2901838421821594 0.36067384481430054
+0.3787209987640381 0.013749952428042889 0.008138757199048996 0.09106231480836868 0.6087870001792908
+0.35299986600875854 0.02067628875374794 0.01247147936373949 0.11075180023908615 0.5457966327667236
+0.34822678565979004 0.002413348061963916 0.005290759261697531 0.10045262426137924 0.5248703360557556
+0.37033170461654663 0.007113362662494183 0.025272028520703316 0.1579497754573822 0.509062647819519
+0.3769926428794861 0.008057202212512493 0.018563034012913704 0.1854351907968521 0.5002010464668274
+0.34124207496643066 0.0007710393983870745 0.0032790263649076223 0.12106974422931671 0.49862051010131836
+0.33905521035194397 0.0033375469502061605 0.013249875046312809 0.11514420807361603 0.5248242616653442
+0.35638368129730225 0.0013632795307785273 0.007743552792817354 0.1470356434583664 0.503702700138092
+0.3497973084449768 0.0025195337366312742 0.003960741683840752 0.10353568941354752 0.4786575138568878
+0.34290504455566406 0.005920455325394869 0.009700055234134197 0.2273598164319992 0.44589361548423767
+0.3615954518318176 0.0014171503717079759 0.0040739006362855434 0.14937371015548706 0.5787273645401001
+0.41063857078552246 0.002255499130114913 0.005986781790852547 0.09502078592777252 0.6534408926963806
+0.3447687327861786 0.0035766444634646177 0.007214525248855352 0.10661543160676956 0.4798853397369385
+0.3562401533126831 0.0007895765593275428 0.002906896872445941 0.11066828668117523 0.5358796715736389
+0.3240315318107605 0.0029903787653893232 0.005108571145683527 0.1400383561849594 0.3869553208351135
+0.373750239610672 0.003571330104023218 0.007284241728484631 0.12729494273662567 0.6582232713699341
+0.3548136055469513 0.0007519174832850695 0.004524227697402239 0.16400189697742462 0.5804428458213806
+0.3619973659515381 0.00241750362329185 0.0031697340309619904 0.09371723979711533 0.5183025002479553
+0.3355450928211212 0.004411137662827969 0.003456463571637869 0.10836211591959 0.42946869134902954
+0.3346332013607025 0.001879339455626905 0.0028432155959308147 0.09825959801673889 0.5278320908546448
+0.3438171446323395 0.0009689802536740899 0.003176557831466198 0.10796388983726501 0.6107577085494995
+0.34433290362358093 0.0038639840204268694 0.005362161435186863 0.11688365042209625 0.5748710632324219
+0.3435223698616028 0.002657021628692746 0.003931642975658178 0.2034880518913269 0.40885674953460693
+0.35076582431793213 0.0006006761104799807 0.0025020246393978596 0.11401759833097458 0.5227364897727966
+0.3373703062534332 0.0017206246266141534 0.0035373633727431297 0.15505222976207733 0.4724853038787842
+0.34637248516082764 0.0011017307406291366 0.0028698125388473272 0.1477586030960083 0.5119644999504089
+0.3513145446777344 0.0009218336781486869 0.00726600456982851 0.12019632756710052 0.5888746976852417
+0.3610309064388275 0.0011102970456704497 0.0108108539134264 0.13320830464363098 0.5671521425247192
+0.30560535192489624 0.0009807756869122386 0.005319010000675917 0.17653587460517883 0.37993836402893066
+0.3257027864456177 0.002658911980688572 0.003692498430609703 0.10859010368585587 0.40664348006248474
+0.3485310971736908 0.0018332006875425577 0.0045846146531403065 0.11767049133777618 0.48415806889533997
+0.3449050188064575 0.0009158245520666242 0.0025020958855748177 0.1115061417222023 0.5695908665657043
+0.3296607732772827 0.005204139277338982 0.007987256161868572 0.18916484713554382 0.49893292784690857
+0.32830506563186646 0.003392507554963231 0.008162638172507286 0.1808396428823471 0.3994455933570862
+0.3625112771987915 0.0011258251033723354 0.0039070588536560535 0.18737494945526123 0.5268020629882812
+0.33435767889022827 0.0053968350403010845 0.009334199130535126 0.08819041401147842 0.5340674519538879
+0.3488425612449646 0.002500239061191678 0.006869742181152105 0.1514844447374344 0.5096208453178406
+0.33112043142318726 0.0012629932025447488 0.002907715504989028 0.10916078090667725 0.4813099503517151
+0.3201713263988495 0.005787172820419073 0.0040101902559399605 0.20838741958141327 0.4245392084121704
+0.320740669965744 0.003807222004979849 0.00236813398078084 0.1035214513540268 0.5129687786102295
+0.3379800319671631 0.0006956505822017789 0.0027633351273834705 0.12774020433425903 0.5411680936813354
+0.3461045026779175 0.0012362650595605373 0.003595511196181178 0.17127850651741028 0.4656793177127838
+0.3601912260055542 0.0012345816940069199 0.002413891488686204 0.1321868896484375 0.5025163292884827
+0.3413642644882202 0.0007197507075034082 0.003950056154280901 0.16456139087677002 0.5739579200744629
+0.34552711248397827 0.0006826833123341203 0.005852046888321638 0.2004161775112152 0.5599491596221924
+0.359618604183197 0.0006502314936369658 0.0041537098586559296 0.09244237095117569 0.5675091743469238
+0.35200926661491394 0.001103531918488443 0.002183526987209916 0.11057183891534805 0.49818408489227295
+0.3304121494293213 0.001113298931159079 0.00271129934117198 0.1331394910812378 0.44337156414985657
+0.36149123311042786 0.0005402094102464616 0.0022602935787290335 0.19012552499771118 0.4823719561100006
+0.3253084719181061 0.0012913316022604704 0.0021862208377569914 0.08184932172298431 0.5185840129852295
+0.31753093004226685 0.0028138435445725918 0.003884360659867525 0.11343257874250412 0.507672905921936
+0.3595743179321289 0.0017864620313048363 0.0050223409198224545 0.14056411385536194 0.5121967792510986
+0.3536849319934845 0.0007340098381973803 0.003328867256641388 0.12524670362472534 0.569648027420044
+0.32273319363594055 0.0029743113555014133 0.002689527114853263 0.15279360115528107 0.448824942111969
+0.373492032289505 0.003335579065605998 0.0036627287045121193 0.17923924326896667 0.5133458375930786
+0.35753726959228516 0.0009023313759826124 0.0032505965791642666 0.14510777592658997 0.6017482876777649
+0.37525132298469543 0.0014470511814579368 0.004999141674488783 0.13963505625724792 0.55865079164505
+0.35799843072891235 0.002544487826526165 0.006072161253541708 0.11332839727401733 0.5032022595405579
+0.33714163303375244 0.0010513064917176962 0.0036693927831947803 0.1419985145330429 0.44708916544914246
+0.33804017305374146 0.0007594551425427198 0.0038837878964841366 0.06732478737831116 0.5973020792007446
+0.3290504217147827 0.0010580195812508464 0.00391758419573307 0.12468143552541733 0.4286574721336365
+0.2942064106464386 0.0006745666614733636 0.0045699262991547585 0.1561843603849411 0.3944116234779358
+0.3245784342288971 0.0006088166264817119 0.005356130190193653 0.12029317766427994 0.4919816851615906
+0.314312219619751 0.0006131319678388536 0.004027194809168577 0.10226630419492722 0.41906818747520447
+0.35063785314559937 0.0005574324750341475 0.0029943713452667 0.1430215686559677 0.5402827858924866
+0.3520655632019043 0.0006116416770964861 0.0032974970526993275 0.1555582731962204 0.5371075868606567
+0.33419185876846313 0.0007466298993676901 0.005017837509512901 0.12483356148004532 0.4706055521965027
+0.33757588267326355 0.0005324012017808855 0.004930454771965742 0.17805498838424683 0.5068127512931824
+0.36135387420654297 0.0007255152449943125 0.0036494634114205837 0.1631423532962799 0.521215558052063
+0.3293727934360504 0.0009109757374972105 0.003933370113372803 0.11120327562093735 0.502640426158905
+0.3631136417388916 0.00048131824587471783 0.0037806404288858175 0.22416658699512482 0.5671995878219604
+0.36675581336021423 0.001202146871946752 0.0028890217654407024 0.18034674227237701 0.5146037936210632
+0.36677882075309753 0.0021099152509123087 0.0028800691943615675 0.1395423412322998 0.553443968296051
+0.34870007634162903 0.0013162941904738545 0.0026781067717820406 0.08780667185783386 0.551288366317749
+0.3416190445423126 0.0005239943275228143 0.003030295018106699 0.1545514464378357 0.5273030996322632
+0.37837934494018555 0.0009984351927414536 0.003387085860595107 0.1244039386510849 0.6111761927604675
+0.3296346664428711 0.001593910506926477 0.005208070855587721 0.14824344217777252 0.4549228549003601
+0.3421512842178345 0.0019654002971947193 0.0032746775541454554 0.1858920454978943 0.40093758702278137
+0.3309011459350586 0.0011689785169437528 0.002598252147436142 0.15950000286102295 0.45995795726776123
+0.3376140892505646 0.000639347592368722 0.002499133348464966 0.06684693694114685 0.5250762104988098
+0.3316119909286499 0.002891097916290164 0.007124063093215227 0.12568041682243347 0.5485724806785583
+0.3429766893386841 0.0022387709468603134 0.009947764687240124 0.18252691626548767 0.41346365213394165
+0.3465558886528015 0.0009533557458780706 0.004143827594816685 0.14818638563156128 0.4971637427806854
+0.34241029620170593 0.005172128323465586 0.008191441185772419 0.1608530879020691 0.4670180678367615
+0.3186416029930115 0.0026454466860741377 0.008145295083522797 0.07912607491016388 0.4927091896533966
+0.37948834896087646 0.0016629862366244197 0.0034444069024175406 0.12722517549991608 0.6038126945495605
+0.33170086145401 0.0086435591802001 0.01038429792970419 0.15503168106079102 0.4646180272102356
+0.3542919158935547 0.0043963720090687275 0.007370410952717066 0.1985836923122406 0.5077810287475586
+0.36201012134552 0.0008876848150976002 0.005533158779144287 0.21413692831993103 0.4738970100879669
+0.34734487533569336 0.004967545159161091 0.011292948387563229 0.1915108859539032 0.5445220470428467
+0.3206760883331299 0.0018769126618281007 0.006518151611089706 0.14747245609760284 0.4542575478553772
+0.3624911904335022 0.0015972155379131436 0.003893753746524453 0.1621912717819214 0.562061607837677
+0.3744664490222931 0.005237409379333258 0.007397891022264957 0.1887626200914383 0.5507529377937317
+0.3274242579936981 0.0021150056272745132 0.003220420563593507 0.08816242963075638 0.4952486753463745
+0.3174581229686737 0.0008246933575719595 0.002822084119543433 0.08668047934770584 0.4824340045452118
+0.34795400500297546 0.0022637913934886456 0.0031168635468930006 0.12034210562705994 0.5396084189414978
+0.3424244821071625 0.0010536317713558674 0.0032981911208480597 0.0824299156665802 0.49490997195243835
+0.33510637283325195 0.0007234126678667963 0.006750392261892557 0.16973243653774261 0.4715968370437622
+0.3371499180793762 0.0006774274515919387 0.004956108517944813 0.1687631905078888 0.5004516839981079
+0.31310147047042847 0.0008425716659985483 0.002387900371104479 0.0860978364944458 0.5014434456825256
+0.34258905053138733 0.0008862854447215796 0.0024190344847738743 0.16120921075344086 0.4683844745159149
+0.3430710434913635 0.0006202304502949119 0.0021672328002750874 0.1617414355278015 0.4868769645690918
+0.3563372492790222 0.001500731916166842 0.0032406977843493223 0.15654024481773376 0.4809362292289734
+0.28531694412231445 0.0013125851983204484 0.005033436231315136 0.10724836587905884 0.41235554218292236
+0.332769513130188 0.000667584827169776 0.0034979491028934717 0.12731212377548218 0.48618119955062866
+0.3638625741004944 0.0018305765697732568 0.002958210650831461 0.1451728492975235 0.5617222785949707
+0.30239975452423096 0.0016688040923327208 0.0030413083732128143 0.14452964067459106 0.38714921474456787
+0.31459569931030273 0.000568003801163286 0.002905329456552863 0.1247892677783966 0.4703259766101837
+0.3545984625816345 0.0011030270252376795 0.0035416588652879 0.08593428879976273 0.5363925099372864
+0.34861546754837036 0.0017970707267522812 0.004185972735285759 0.09313815832138062 0.5450232625007629
+0.30603256821632385 0.0008694014977663755 0.003689580364152789 0.1904570460319519 0.4315025210380554
+0.2838924527168274 0.0007188473246060312 0.00332041853107512 0.09143231064081192 0.40608084201812744
+0.3093990981578827 0.0015418360708281398 0.004087251145392656 0.08350346982479095 0.5014048218727112
+0.3205983638763428 0.0008194753318093717 0.003850359469652176 0.2160613238811493 0.4630083739757538
+0.32249927520751953 0.0005785896792076528 0.0043711550533771515 0.2942065894603729 0.39719125628471375
+0.3265811800956726 0.0007611963665112853 0.002909808186814189 0.11918129771947861 0.4004608988761902
+0.32958102226257324 0.001012683380395174 0.002848715288564563 0.10386863350868225 0.49680304527282715
+0.3385058641433716 0.0011488450691103935 0.0027052192017436028 0.07642406970262527 0.5583894848823547
+0.32587355375289917 0.0011144254822283983 0.003617284819483757 0.11317554116249084 0.46443942189216614
+0.302608847618103 0.0007309539942070842 0.00416211225092411 0.1703610122203827 0.41501834988594055
+0.3173374533653259 0.0015708600403741002 0.0038740760646760464 0.14656953513622284 0.488237202167511
+0.33100858330726624 0.0020817925687879324 0.005227773915976286 0.13738225400447845 0.46896055340766907
+0.31517189741134644 0.0004969023284502327 0.0026101870462298393 0.16414088010787964 0.43219447135925293
+0.3079584836959839 0.0020939435344189405 0.0035687685012817383 0.12167056649923325 0.473824143409729
+0.3215462565422058 0.003491343930363655 0.006785344332456589 0.2045665681362152 0.4654960632324219
+0.30090972781181335 0.0009581808699294925 0.005934230517596006 0.11130481213331223 0.44637173414230347
+0.35316163301467896 0.0019183752592653036 0.0030644137877970934 0.1777619570493698 0.44177836179733276
+0.2677460014820099 0.004600379150360823 0.0032312078401446342 0.15945744514465332 0.30784711241722107
+0.33903974294662476 0.0022209223825484514 0.003519318299368024 0.10560313612222672 0.6018669009208679
+0.3464590907096863 0.0007291766814887524 0.007146249059587717 0.1301499754190445 0.5292196273803711
+0.3000906705856323 0.001665875781327486 0.007515473756939173 0.18062536418437958 0.40962502360343933
+0.2939777672290802 0.0005211451789364219 0.0028972462750971317 0.09632651507854462 0.32632964849472046
+0.31384336948394775 0.000962148595135659 0.007462163455784321 0.09594620019197464 0.5038414001464844
+0.34200727939605713 0.00043149531120434403 0.005449090618640184 0.06062030419707298 0.5580022931098938
+0.3487141728401184 0.0028612492606043816 0.0036716635804623365 0.11178695410490036 0.5219113826751709
+0.3435836732387543 0.003764144843444228 0.006939626298844814 0.20554685592651367 0.37090328335762024
+0.3295470178127289 0.0004753095854539424 0.003532696980983019 0.08778329938650131 0.5242910981178284
+0.32715490460395813 0.002418175106868148 0.007563352584838867 0.13053420186042786 0.5290947556495667
+0.28308263421058655 0.0014092333149164915 0.005358255002647638 0.09520315378904343 0.39476749300956726
+0.33446764945983887 0.001067152712494135 0.003066589590162039 0.21600191295146942 0.4502127766609192
+0.33792412281036377 0.0038160658441483974 0.008589835837483406 0.08372760564088821 0.4684474766254425
+0.32530477643013 0.0016527610132470727 0.006953096948564053 0.170667365193367 0.5006933808326721
+0.3420689105987549 0.0013600284000858665 0.003909342922270298 0.09834107011556625 0.5145297646522522
+0.3238297998905182 0.004463919438421726 0.0033109227661043406 0.08938782662153244 0.45520472526550293
+0.3525996506214142 0.002820714144036174 0.0029640779830515385 0.18437613546848297 0.38503777980804443
+0.30793148279190063 0.0006875635008327663 0.004165351390838623 0.15192225575447083 0.536950409412384
+0.3360435962677002 0.0017305203946307302 0.005666991230100393 0.12760692834854126 0.4977177679538727
+0.30335861444473267 0.00175628496799618 0.004781376104801893 0.11478018015623093 0.5032752156257629
+0.29840153455734253 0.0005318630137480795 0.002634284319356084 0.19391177594661713 0.35120680928230286
+0.2813180387020111 0.0009247589623555541 0.003299084724858403 0.09163845330476761 0.3601134121417999
+0.2906022071838379 0.0009198289480991662 0.0033040717244148254 0.17895393073558807 0.43570300936698914
+0.33825603127479553 0.0005159893771633506 0.0035277537535876036 0.15890999138355255 0.5800608992576599
+0.3162655532360077 0.0011356297181919217 0.005418646149337292 0.29882916808128357 0.41825366020202637
+0.32048743963241577 0.0006777588278055191 0.005925607867538929 0.1283835619688034 0.482187956571579
+0.3558042645454407 0.0006456669652834535 0.0035891029983758926 0.21124224364757538 0.5398684740066528
+0.30855971574783325 0.000677484436891973 0.003745464840903878 0.13389788568019867 0.4921908974647522
+0.33166825771331787 0.0004419171600602567 0.0027395247016102076 0.17397598922252655 0.437843382358551
+0.32223206758499146 0.0008126808097586036 0.0032981201075017452 0.18902717530727386 0.43794190883636475
+0.30766546726226807 0.0006563842762261629 0.0033524122554808855 0.16832603514194489 0.4849390685558319
+0.33594614267349243 0.00044803539640270174 0.0026144692674279213 0.18032361567020416 0.4613492488861084
+0.30539149045944214 0.0005686894292011857 0.0024785909336060286 0.1195533275604248 0.4857805073261261
+0.3099800944328308 0.0007723906892351806 0.0026981602422893047 0.10336528718471527 0.4713127315044403
+0.3087191581726074 0.0006616658065468073 0.00302203674800694 0.17909477651119232 0.46350565552711487
+0.299526184797287 0.0010479650227352977 0.0028079829644411802 0.21169798076152802 0.40237608551979065
+0.2722978889942169 0.0012862208532169461 0.002295341342687607 0.07819163799285889 0.4058336913585663
+0.3613901734352112 0.0012159427860751748 0.0021532895043492317 0.13960859179496765 0.5727534294128418
+0.3004056215286255 0.0006197550683282316 0.002656059805303812 0.18666569888591766 0.3942602574825287
+0.31706011295318604 0.0009974729036912322 0.0032717136200517416 0.21005383133888245 0.44006749987602234
+0.33007657527923584 0.0007451645215041935 0.0030669872649013996 0.17280560731887817 0.502031683921814
+0.33288776874542236 0.0005571067449636757 0.00239939265884459 0.10739364475011826 0.515995442867279
+0.32171571254730225 0.0005824488471262157 0.002919516758993268 0.15091994404792786 0.5016687512397766
+0.293112576007843 0.0025833467952907085 0.003270347835496068 0.16117985546588898 0.41256317496299744
+0.3110606074333191 0.0006223527598194778 0.002669485053047538 0.10873401165008545 0.3901715576648712
+0.33498820662498474 0.0006518106674775481 0.002843023044988513 0.0658649429678917 0.49949005246162415
+0.33710765838623047 0.0006836166139692068 0.003265080275014043 0.23233970999717712 0.5265589952468872
+0.3497334420681 0.0007153485785238445 0.004221930168569088 0.13552233576774597 0.5987433791160583
+0.2787758708000183 0.0005592535599134862 0.004656950943171978 0.1728542298078537 0.3372558355331421
+0.30991965532302856 0.0012014296371489763 0.00294483988545835 0.18797491490840912 0.3142653703689575
+0.30301856994628906 0.0014531449414789677 0.004775413777679205 0.16810005903244019 0.4129755198955536
+0.34773027896881104 0.0005912953056395054 0.0025414973497390747 0.18946605920791626 0.6115089654922485
+0.33679044246673584 0.00429065665230155 0.005091287195682526 0.16537170112133026 0.5326866507530212
+0.3449426293373108 0.003669008379802108 0.007901390083134174 0.21072842180728912 0.4171246886253357
+0.29905277490615845 0.0006382808205671608 0.003221221035346389 0.13979774713516235 0.38014116883277893
+0.3532235026359558 0.004878039471805096 0.007379805203527212 0.07316126674413681 0.5790837407112122
+0.3017949163913727 0.0020687608048319817 0.004340203944593668 0.1229458898305893 0.4947505593299866
+0.306598424911499 0.0012246418045833707 0.004808921832591295 0.09948068112134933 0.49713948369026184
+0.3039783239364624 0.003493937198072672 0.007107059471309185 0.14015595614910126 0.42215731739997864
+0.2690722644329071 0.0008873622282408178 0.002323784399777651 0.11279144138097763 0.32908809185028076
+0.3190477192401886 0.0010411967523396015 0.004491531290113926 0.07007822394371033 0.48087507486343384
+0.32512781023979187 0.0007407405646517873 0.0026188166812062263 0.10588552057743073 0.5655956864356995
+0.3095417022705078 0.0014639428118243814 0.007164421491324902 0.10641340911388397 0.4704495966434479
+0.338018536567688 0.0013600497040897608 0.014011386781930923 0.13795730471611023 0.4728070795536041
+0.28325164318084717 0.0013255556114017963 0.003896336304023862 0.10652624815702438 0.3580278754234314
+0.31027811765670776 0.005483152344822884 0.006232240237295628 0.09626751393079758 0.4392992854118347
+0.2896445393562317 0.0016997928032651544 0.004437639378011227 0.14932681620121002 0.4141322672367096
+0.30821743607521057 0.0029116582591086626 0.0035962064284831285 0.12347913533449173 0.4878503084182739
+0.2885040044784546 0.007845036685466766 0.013346950523555279 0.20264862477779388 0.4025478959083557
+0.3219020366668701 0.001843289821408689 0.006270132958889008 0.19001300632953644 0.4770919680595398
+0.29534855484962463 0.003126503201201558 0.005221279338002205 0.13342483341693878 0.3858465552330017
+0.29274559020996094 0.0070245168171823025 0.010151642374694347 0.08821474015712738 0.38928675651550293
+0.2951754033565521 0.0010481843492016196 0.00443563936278224 0.1281713992357254 0.43181338906288147
+0.3304411470890045 0.003935304470360279 0.005215399898588657 0.1302785724401474 0.5437004566192627
+0.2902466952800751 0.0075881388038396835 0.009262826293706894 0.2638900876045227 0.4169091582298279
+0.3479556143283844 0.0014075960498303175 0.0036267819814383984 0.2148664891719818 0.4652866721153259
+0.2959524989128113 0.0025652956683188677 0.005433619488030672 0.10351363569498062 0.4365942180156708
+0.31675097346305847 0.004228804260492325 0.00613761693239212 0.13270962238311768 0.42283937335014343
+0.30253246426582336 0.0008094661752693355 0.002949808957055211 0.11690457910299301 0.4279620051383972
+0.30568379163742065 0.001930341124534607 0.0053362660109996796 0.16733863949775696 0.4314332604408264
+0.296597421169281 0.002225207630544901 0.005725162569433451 0.1089102253317833 0.432807594537735
+0.28144100308418274 0.0004747166531160474 0.0022641178220510483 0.07877806574106216 0.4618414342403412
+0.3797665238380432 0.0011386544210836291 0.003748225746676326 0.10833247005939484 0.5359636545181274
+0.33138567209243774 0.0008256238070316613 0.003163381712511182 0.17243742942810059 0.4423467814922333
+0.36432069540023804 0.00046340562403202057 0.002765907673165202 0.15476901829242706 0.44632965326309204
+0.2847415506839752 0.0007261406863108277 0.0030802451074123383 0.0735323429107666 0.44857141375541687
+0.33382177352905273 0.0007039293413981795 0.0024549670051783323 0.1452091485261917 0.5648823380470276
+0.29398179054260254 0.0006580743938684464 0.002832072088494897 0.09778683632612228 0.5075559020042419
+0.29568639397621155 0.0005493510980159044 0.003099514404311776 0.09214960783720016 0.4248122572898865
+0.3308369219303131 0.0008945927838794887 0.00320448842830956 0.1421794444322586 0.39010089635849
+0.34249311685562134 0.0016512422589585185 0.002753426553681493 0.1843545138835907 0.41385796666145325
+0.3622179627418518 0.0012188961263746023 0.002578926272690296 0.09215983003377914 0.6292164921760559
+0.3451022207736969 0.0007112182793207467 0.01911330223083496 0.16364231705665588 0.597374439239502
+0.30990511178970337 0.001249894849024713 0.007633065804839134 0.12423466145992279 0.3391735851764679
+0.2947191596031189 0.0006351461634039879 0.0030517028644680977 0.08479609340429306 0.4021596312522888
+0.30557429790496826 0.0012309981975704432 0.004177062772214413 0.09374605119228363 0.46310579776763916
+0.2867680788040161 0.0006481502205133438 0.00414231838658452 0.05837222933769226 0.4409056305885315
+0.3194391131401062 0.0010610141325742006 0.0026941250544041395 0.18385496735572815 0.5067129135131836
+0.32352370023727417 0.002384550403803587 0.004693089053034782 0.19791094958782196 0.43601754307746887
+0.30215075612068176 0.0008221607422456145 0.002995165530592203 0.14785777032375336 0.37114524841308594
+0.3043128252029419 0.0006154498551040888 0.004676981829106808 0.13056133687496185 0.46285054087638855
+0.3392881453037262 0.0005760826752521098 0.004399306606501341 0.06547854095697403 0.5446019172668457
+0.28843289613723755 0.0008335462189279497 0.0025216988287866116 0.15572769939899445 0.3474982678890228
+0.31556546688079834 0.0015772600891068578 0.005359266418963671 0.1321074664592743 0.3434607982635498
+0.326283723115921 0.0005039817770011723 0.002878352766856551 0.08618029206991196 0.5266990661621094
+0.3049750030040741 0.00117337831761688 0.002289997413754463 0.07749830186367035 0.49738070368766785
+0.2797914743423462 0.0014350197743624449 0.002703857608139515 0.07767447084188461 0.43249770998954773
+0.3124253749847412 0.0007030105334706604 0.0035724807530641556 0.0831509530544281 0.4279922544956207
+0.29434341192245483 0.0004754052497446537 0.003761444240808487 0.17132119834423065 0.3918013274669647
+0.30018532276153564 0.0004971917369402945 0.002587190130725503 0.15665490925312042 0.40375977754592896
+0.29770058393478394 0.0004633823409676552 0.002430130261927843 0.15859277546405792 0.43639472126960754
+0.29493582248687744 0.0004922208609059453 0.0022852483671158552 0.11833740025758743 0.46968352794647217
+0.3133953809738159 0.0005426423740573227 0.0032211770303547382 0.150558739900589 0.4692138433456421
+0.2888950705528259 0.0004928911221213639 0.002667932538315654 0.1990804374217987 0.3439381718635559
+0.3430393934249878 0.0006283959955908358 0.003290510270744562 0.25893768668174744 0.4587298035621643
+0.3276526629924774 0.0005042645498178899 0.0037926076911389828 0.09723420441150665 0.5496683716773987
+0.2757677733898163 0.0009919734438881278 0.002502505900338292 0.11011814326047897 0.4224356710910797
+0.29096531867980957 0.002007243689149618 0.004266092088073492 0.16893374919891357 0.4420429468154907
+0.3112475275993347 0.0007912548026069999 0.002942124381661415 0.10506602376699448 0.4138551354408264
+0.3201406002044678 0.0012339026434347034 0.0028851397801190615 0.19361019134521484 0.3874245882034302
+0.3340182304382324 0.0016283165896311402 0.0031061191111803055 0.07374568283557892 0.5608612298965454
+0.3294014632701874 0.0005201580934226513 0.0025198673829436302 0.19710473716259003 0.4950604736804962
+0.2733314633369446 0.001559721422381699 0.008010623045265675 0.11266371607780457 0.34477540850639343
+0.2706471085548401 0.0007366430945694447 0.005316266790032387 0.1368030309677124 0.4118717312812805
+0.28643178939819336 0.0016675236402079463 0.0026617522817105055 0.09806463122367859 0.4163844585418701
+0.29633140563964844 0.0023105626460164785 0.004385207314044237 0.11743713170289993 0.4860536754131317
+0.29036444425582886 0.0006775858928449452 0.0021863726433366537 0.17692475020885468 0.36605456471443176
+0.3092937767505646 0.003011127933859825 0.004475366789847612 0.12174046039581299 0.4237658679485321
+0.2773400843143463 0.002583924913778901 0.004881455563008785 0.12906616926193237 0.3560630977153778
+0.33014634251594543 0.0005575285758823156 0.002676288830116391 0.1127941757440567 0.5135669708251953
+0.2967175245285034 0.0015566879883408546 0.002635219832882285 0.07261941581964493 0.4749397039413452
+0.28064608573913574 0.0016746899345889688 0.0033167945221066475 0.12754225730895996 0.36404597759246826
+0.2760560214519501 0.0007600801181979477 0.0031875832937657833 0.19195905327796936 0.38027259707450867
+0.2754448652267456 0.0006888805655762553 0.0028162309899926186 0.08683769404888153 0.42780208587646484
+0.27935606241226196 0.0009891381487250328 0.0031729687470942736 0.06418884545564651 0.45253556966781616
+0.29356494545936584 0.0010949982097372413 0.003326197387650609 0.08766651898622513 0.38374602794647217
+0.288376122713089 0.0006230506696738303 0.0025217412039637566 0.11831764131784439 0.39269569516181946
+0.2955041825771332 0.0006511606043204665 0.0032440091017633677 0.1983402818441391 0.4354826509952545
+0.31635308265686035 0.0009217532933689654 0.0038458192721009254 0.16074371337890625 0.4930310547351837
+0.29713180661201477 0.000829943164717406 0.003828872926533222 0.12592965364456177 0.3988230526447296
+0.28443631529808044 0.0005528389592655003 0.0031428050715476274 0.08269702643156052 0.38368991017341614
+0.29210028052330017 0.001666680327616632 0.0025334740057587624 0.10824072360992432 0.3631027340888977
+0.2856500446796417 0.0010878167813643813 0.0022422249894589186 0.10054971277713776 0.4598011374473572
+0.27653923630714417 0.0018888748018071055 0.0032831670250743628 0.1784064769744873 0.4419822096824646
+0.31692081689834595 0.0010516969487071037 0.0045209466479718685 0.10125528275966644 0.4429921805858612
+0.2924119532108307 0.0011723017087206244 0.003261455800384283 0.15131768584251404 0.37303340435028076
+0.3130512833595276 0.0035502463579177856 0.005378656089305878 0.15615180134773254 0.47153329849243164
+0.3130914568901062 0.0012314211344346404 0.00390630355104804 0.07468269020318985 0.5431632995605469
+0.3044896721839905 0.0017094939248636365 0.00238116760738194 0.13783299922943115 0.4582479000091553
+0.3513648808002472 0.00458763400092721 0.003259846940636635 0.18645228445529938 0.3513258099555969
+0.31135717034339905 0.0016622766852378845 0.001974083948880434 0.1270127147436142 0.3884624242782593
+0.28780341148376465 0.0006033911486156285 0.0024857281241565943 0.12755924463272095 0.4530198574066162
+0.29022932052612305 0.000933499017264694 0.002598581602796912 0.16050291061401367 0.511339008808136
+0.29853880405426025 0.0007294333772733808 0.007436675485223532 0.1408950537443161 0.44846537709236145
+0.31369853019714355 0.003947009798139334 0.006625031586736441 0.24475397169589996 0.34065356850624084
+0.3123728632926941 0.0021807353477925062 0.0021503749303519726 0.13323111832141876 0.4122824966907501
+0.2873290181159973 0.0014324296498671174 0.00246976176276803 0.08036478608846664 0.4065547585487366
+0.31744489073753357 0.0011165085015818477 0.002321577398106456 0.08271180093288422 0.5606828331947327
+0.3187313675880432 0.004906275309622288 0.01025004219263792 0.1497262716293335 0.6064118146896362
+0.28837066888809204 0.0023503154516220093 0.011542213149368763 0.1448502093553543 0.4246469736099243
+0.3237486481666565 0.002707174513489008 0.004047634080052376 0.20118552446365356 0.2280256599187851
+0.31451839208602905 0.01010366715490818 0.016874153167009354 0.0982336550951004 0.4049069881439209
+0.3099513053894043 0.0017071731854230165 0.010068519972264767 0.10590940713882446 0.5196599364280701
+0.3496437072753906 0.00890024658292532 0.004831824917346239 0.15963110327720642 0.61163729429245
+0.32054609060287476 0.021912546828389168 0.021403124555945396 0.2309969961643219 0.40668538212776184
+0.3297008275985718 0.0033625599462538958 0.006621577776968479 0.1865454912185669 0.3725408613681793
+0.320679634809494 0.008688854984939098 0.011342233046889305 0.14932481944561005 0.42282140254974365
+0.33850541710853577 0.016079740598797798 0.018599223345518112 0.14513379335403442 0.4782325327396393
+0.2936311960220337 0.0010376318823546171 0.0024103547912091017 0.07634148746728897 0.47372370958328247
+0.30756840109825134 0.008701787330210209 0.018522487953305244 0.18038548529148102 0.39594966173171997
+0.3230786919593811 0.008868813514709473 0.01550524402409792 0.1960517168045044 0.46873918175697327
+0.2865709662437439 0.000558509724214673 0.002102535218000412 0.14548300206661224 0.39600619673728943
+0.29769694805145264 0.005562028381973505 0.007716604974120855 0.0697503536939621 0.4927820861339569
+0.3161795139312744 0.0031837697606533766 0.0039390441961586475 0.1695433109998703 0.47575318813323975
+0.3409470319747925 0.00086348841432482 0.007235545665025711 0.23756304383277893 0.3691563010215759
+0.30553191900253296 0.002006995026022196 0.009030274115502834 0.16969513893127441 0.41306272149086
+0.27789953351020813 0.0004995795898139477 0.002753246808424592 0.14434674382209778 0.4003441631793976
+0.31692224740982056 0.0009961556643247604 0.005848935339599848 0.160671204328537 0.5224803686141968
+0.25913870334625244 0.0019574312027543783 0.0031668459996581078 0.10902160406112671 0.36431896686553955
+0.28887027502059937 0.0020775182638317347 0.0035658804699778557 0.12560175359249115 0.46623024344444275
+0.30533063411712646 0.002238806104287505 0.006205801852047443 0.2106168419122696 0.3609752357006073
+0.3141798973083496 0.0005069102626293898 0.0027213445864617825 0.09239278733730316 0.4271652400493622
+0.298757404088974 0.0027737983036786318 0.006234869826585054 0.13354596495628357 0.45462852716445923
+0.26273953914642334 0.0010398655431345105 0.003359920112416148 0.11953216791152954 0.39865824580192566
+0.29001498222351074 0.0015819516265764832 0.004298403859138489 0.22047244012355804 0.40686503052711487
+0.31186264753341675 0.003494460368528962 0.009642643854022026 0.13079853355884552 0.39818423986434937
+0.29333752393722534 0.0005965268355794251 0.004094711504876614 0.07896273583173752 0.3974165916442871
+0.3096131384372711 0.002749871462583542 0.003416562220081687 0.14969612658023834 0.4556787610054016
+0.3153846263885498 0.0031492169946432114 0.004494873341172934 0.08546715974807739 0.49637913703918457
+0.29821860790252686 0.0005312534631229937 0.002470102859660983 0.14631932973861694 0.42002102732658386
+0.2928595542907715 0.0023274796549230814 0.006027881521731615 0.15203237533569336 0.4229349195957184
+0.2795664370059967 0.0016722057480365038 0.00353312399238348 0.18093587458133698 0.38459286093711853
+0.2608511745929718 0.0004620934196282178 0.004910466261208057 0.11100414395332336 0.40117916464805603
+0.2799926996231079 0.0010148079600185156 0.003991978708654642 0.07605695724487305 0.4698259234428406
+0.23824110627174377 0.0004757199785672128 0.002144239144399762 0.0925430878996849 0.3046468496322632
+0.25926288962364197 0.0013453153660520911 0.003530274610966444 0.1070982962846756 0.3172129690647125
+0.28728610277175903 0.0015634590527042747 0.00592017313465476 0.0803348645567894 0.457335501909256
+0.3111264109611511 0.0007462575449608266 0.004461858421564102 0.09943090379238129 0.45769765973091125
+0.31806161999702454 0.002297259634360671 0.0032742845360189676 0.19849833846092224 0.5251034498214722
+0.30569884181022644 0.0029414063319563866 0.002498371759429574 0.16144295036792755 0.3717106580734253
+0.3197728991508484 0.00131047866307199 0.002401952398940921 0.127800852060318 0.4154801666736603
+0.2895912528038025 0.0004901557695120573 0.0021756477653980255 0.17379845678806305 0.4059033989906311
+0.2308446168899536 0.0018036863766610622 0.0023935181088745594 0.1364157646894455 0.3254898488521576
+0.2847112715244293 0.0019589432049542665 0.002740156603977084 0.10980462282896042 0.4361252188682556
+0.28875821828842163 0.0006377780809998512 0.0022768662311136723 0.09698760509490967 0.4454672932624817
+0.33899977803230286 0.0006993146962486207 0.003158277366310358 0.25089916586875916 0.4050368666648865
+0.26445886492729187 0.0011643352918326855 0.0036236648447811604 0.1084240972995758 0.28011661767959595
+0.3027992248535156 0.0005955079686827958 0.0031524389050900936 0.2018207162618637 0.5336143374443054
+0.2959417402744293 0.0011129723861813545 0.0031633353792130947 0.14608946442604065 0.4926775097846985
+0.27103477716445923 0.0019589420408010483 0.004895226564258337 0.18122781813144684 0.2860729694366455
+0.2968405783176422 0.0006126051302999258 0.0023446993436664343 0.13739538192749023 0.3301770091056824
+0.24653884768486023 0.0012711116578429937 0.0033028468023985624 0.16428180038928986 0.32934778928756714
+0.2578433156013489 0.0010599468369036913 0.0029429439455270767 0.10477958619594574 0.4321378767490387
+0.3119875192642212 0.0007934666937217116 0.004902561195194721 0.14820991456508636 0.5589194893836975
+0.2453588843345642 0.001460670493543148 0.012593643739819527 0.15168480575084686 0.33444684743881226
+0.3298189043998718 0.0006760517135262489 0.0036928770132362843 0.13547420501708984 0.3202681541442871
+0.2627313733100891 0.0038138465024530888 0.005479378625750542 0.08995700627565384 0.3799292743206024
+0.294503390789032 0.0019290235359221697 0.007164498325437307 0.09024892002344131 0.4425036311149597
+0.27058422565460205 0.0016646472504362464 0.002260347129777074 0.08895358443260193 0.46888625621795654
+0.31482699513435364 0.007411914877593517 0.006590589880943298 0.17603690922260284 0.3775053322315216
+0.2797086834907532 0.0028043054044246674 0.0027931721415370703 0.11027936637401581 0.3473265767097473
+0.27832120656967163 0.0011927776504307985 0.00538651691749692 0.16636958718299866 0.3760579824447632
+0.3011448085308075 0.003310735570266843 0.008108649402856827 0.119601309299469 0.5419957041740417
+0.2666294276714325 0.0005829589790664613 0.0027027197647839785 0.13764478266239166 0.37681320309638977
+0.29510754346847534 0.0028363435994833708 0.016699012368917465 0.1674472838640213 0.28673961758613586
+0.31264185905456543 0.0013583494583144784 0.006488884333521128 0.17250974476337433 0.4219031035900116
+0.26353919506073 0.0013658973621204495 0.0017086225561797619 0.0790453627705574 0.42755165696144104
+0.27578499913215637 0.0030208230018615723 0.0031245634891092777 0.11023195832967758 0.3903992176055908
+0.29417699575424194 0.000979014323092997 0.001874817069619894 0.06888066232204437 0.39655235409736633
+0.2593616843223572 0.0012064490001648664 0.00647591520100832 0.10532484203577042 0.36749354004859924
+0.2525744140148163 0.001559915952384472 0.007131057325750589 0.14543040096759796 0.34971871972084045
+0.26052922010421753 0.0007837580051273108 0.0022641296964138746 0.08116564899682999 0.39531034231185913
+0.2726166844367981 0.0020866107661277056 0.0036050179041922092 0.15575943887233734 0.3394823372364044
+0.28754353523254395 0.0010840472532436252 0.0033614023122936487 0.12770892679691315 0.34289318323135376
+0.27586135268211365 0.0011583664454519749 0.0026009692810475826 0.15069477260112762 0.37603893876075745
+0.27731287479400635 0.0026356931775808334 0.005280293058604002 0.1207272857427597 0.4427005648612976
+0.29808878898620605 0.0010503308149054646 0.003633715445175767 0.10814366489648819 0.5483391880989075
+0.28716495633125305 0.0011700099566951394 0.0035369633696973324 0.08999980241060257 0.3211040198802948
+0.2884010672569275 0.00234952662140131 0.0071036783047020435 0.11830358952283859 0.334404855966568
+0.2866605818271637 0.0007347449427470565 0.00505661079660058 0.10955262929201126 0.4551457166671753
+0.29699814319610596 0.0027382613625377417 0.0023667311761528254 0.1832997351884842 0.5133506655693054
+0.2830509841442108 0.006114857736974955 0.007169796619564295 0.1266407072544098 0.436006098985672
+0.2900640070438385 0.001503663370385766 0.004189653787761927 0.21904219686985016 0.35659345984458923
+0.28065311908721924 0.003021196462213993 0.0029872653540223837 0.19700460135936737 0.3195440471172333
+0.3025938868522644 0.0055440207943320274 0.0036247342359274626 0.17040446400642395 0.4333551824092865
+0.3229556083679199 0.0011009195586666465 0.004006985109299421 0.07772962749004364 0.5187931060791016
+0.3169017434120178 0.002293557161465287 0.015476357191801071 0.18422545492649078 0.4388292729854584
+0.31920289993286133 0.0014891382306814194 0.012303742580115795 0.2902531921863556 0.31560879945755005
+0.29828864336013794 0.0015553537523373961 0.002557465573772788 0.1327410191297531 0.38493290543556213
+0.27690085768699646 0.003483130596578121 0.007489453069865704 0.12340560555458069 0.44885796308517456
+0.2658756971359253 0.0005329818231984973 0.0033937699627131224 0.10745254158973694 0.3933957815170288
+0.3057364225387573 0.004427703097462654 0.0033796143252402544 0.15428350865840912 0.40222933888435364
+0.2777552008628845 0.005264636594802141 0.004161601886153221 0.12916502356529236 0.40417322516441345
+0.2345508337020874 0.0007327410276047885 0.0020150637719780207 0.11934208869934082 0.317030131816864
+0.28830140829086304 0.0016497320029884577 0.006607755087316036 0.13544347882270813 0.4009976387023926
+0.2745536267757416 0.0011107024038210511 0.003314772853627801 0.1122717335820198 0.3397591710090637
+0.27767181396484375 0.0011877174256369472 0.003782873274758458 0.11375461518764496 0.4613057076931
+0.28061503171920776 0.002606427064165473 0.012700955383479595 0.12295372039079666 0.44033709168434143
+0.249683678150177 0.0007092565065249801 0.006837155669927597 0.18602433800697327 0.3745764493942261
+0.2994685769081116 0.004197238944470882 0.0023698117583990097 0.1903829723596573 0.4574669599533081
+0.25501900911331177 0.004965259227901697 0.0028424577321857214 0.12185335159301758 0.33128228783607483
+0.298704594373703 0.0007431801059283316 0.0023615104146301746 0.14805729687213898 0.46395695209503174
+0.27227139472961426 0.0028342027217149734 0.007301037665456533 0.15466617047786713 0.3762206733226776
+0.24835188686847687 0.003096314612776041 0.006932745687663555 0.10908409208059311 0.36268141865730286
+0.2402542680501938 0.0005826167180202901 0.0026796869933605194 0.08580779284238815 0.3566632568836212
+0.2664770483970642 0.0022864106576889753 0.00643550930544734 0.09945907443761826 0.3522396385669708
+0.30411913990974426 0.0014944117283448577 0.006705111358314753 0.14709460735321045 0.35890811681747437
+0.2557223439216614 0.0009831020142883062 0.002112958114594221 0.07577169686555862 0.41434088349342346
+0.28586050868034363 0.004828779958188534 0.0051986719481647015 0.1871764063835144 0.4585556983947754
+0.2550796866416931 0.0030559285078197718 0.004802807234227657 0.1791435331106186 0.2981588840484619
+0.23800146579742432 0.0006380299455486238 0.0025086344685405493 0.0892954170703888 0.31109943985939026
+0.26491081714630127 0.004068575333803892 0.007010991219431162 0.09625294059515 0.32280272245407104
+0.28712335228919983 0.0019320267019793391 0.004291148856282234 0.10625658929347992 0.526618480682373
+0.2769935429096222 0.0012427873443812132 0.003924362361431122 0.16057690978050232 0.3677065670490265
+0.279593825340271 0.003744987305253744 0.007601809687912464 0.21996645629405975 0.30933892726898193
+0.2826142907142639 0.0008019024971872568 0.0025286837480962276 0.12994781136512756 0.3944236636161804
+0.2702380418777466 0.001234649564139545 0.0032362404745072126 0.2015530914068222 0.36482003331184387
+0.2575107216835022 0.0013940365752205253 0.002961635123938322 0.21689863502979279 0.3697492182254791
+0.2522062659263611 0.0006253760075196624 0.0034145924728363752 0.08805305510759354 0.3598381578922272
+0.21571898460388184 0.0014503997517749667 0.009488342329859734 0.15340080857276917 0.28945013880729675
+0.25113049149513245 0.0005453971098177135 0.004927096422761679 0.09556568413972855 0.3915793299674988
+0.2805093824863434 0.0018608500249683857 0.0022761886939406395 0.05867602303624153 0.45602884888648987
+0.24562986195087433 0.0024881449062377214 0.003417250467464328 0.11559540778398514 0.35458657145500183
+0.3068500757217407 0.0005623301258310676 0.00283606955781579 0.1243356466293335 0.4219358265399933
+0.274421751499176 0.0016048381803557277 0.002557594096288085 0.1574782133102417 0.30069682002067566
+0.2769673466682434 0.002558052074164152 0.0023939348757267 0.2084708958864212 0.365480899810791
+0.28026750683784485 0.0014753028517588973 0.001806251355446875 0.16378284990787506 0.4409027099609375
+0.24624022841453552 0.0005123261362314224 0.002454038942232728 0.14631237089633942 0.40454041957855225
+0.2571377754211426 0.0005822476232424378 0.0025770911015570164 0.14209038019180298 0.36995357275009155
+0.25799959897994995 0.0009646114194765687 0.002331568393856287 0.16868989169597626 0.2684428095817566
+0.27405673265457153 0.001080579124391079 0.0019389786757528782 0.08136668801307678 0.3534948229789734
+0.2922815680503845 0.0007029320695437491 0.002348515437915921 0.05019014701247215 0.5161943435668945
+0.265203595161438 0.0012895059771835804 0.0037647895514965057 0.11328120529651642 0.4176510274410248
+0.25937700271606445 0.000840312335640192 0.004742372781038284 0.12010320276021957 0.36236247420310974
+0.2504657804965973 0.0005463780253194273 0.0025037010200321674 0.07859045267105103 0.3116062581539154
+0.25739914178848267 0.0013212086632847786 0.0031969703268259764 0.1553845852613449 0.33166030049324036
+0.23648154735565186 0.0008699032478034496 0.0034985581878572702 0.06693591177463531 0.4012782573699951
+0.2878616750240326 0.0007461672648787498 0.002282962668687105 0.1249324232339859 0.44225960969924927
+0.3013454079627991 0.002482182579115033 0.003176699625328183 0.1245155856013298 0.4592805504798889
+0.2906251847743988 0.0017646830528974533 0.002919210121035576 0.10537952184677124 0.3034375011920929
+0.2755395770072937 0.00045372036402113736 0.002487937221303582 0.10472822189331055 0.3228422701358795
+0.25876665115356445 0.0011500868713483214 0.00286057754419744 0.16119754314422607 0.34980154037475586
+0.2512240409851074 0.0005720886401832104 0.0023971654009073973 0.06782811135053635 0.37065112590789795
+0.2853317856788635 0.000755232002120465 0.006852480582892895 0.08105269074440002 0.41139882802963257
+0.28180426359176636 0.0005883142002858222 0.006837806198745966 0.08629840612411499 0.4357226490974426
+0.2731192708015442 0.0011122594587504864 0.0026370652485638857 0.10692752152681351 0.394422709941864
+0.2498789280653 0.002360796555876732 0.003162429900839925 0.06179482117295265 0.3542421758174896
+0.24909608066082 0.0006925183697603643 0.0027027148753404617 0.09011407196521759 0.35247287154197693
+0.2573668658733368 0.001638431684114039 0.0018123635090887547 0.06546555459499359 0.4148523509502411
+0.26634731888771057 0.0038962452672421932 0.004107489716261625 0.15145619213581085 0.387103796005249
+0.2746048867702484 0.0014362949877977371 0.002600096631795168 0.2172503024339676 0.3877169191837311
+0.24080170691013336 0.0007047764956951141 0.0030218269675970078 0.09033498167991638 0.32032689452171326
+0.27229565382003784 0.0021502720192074776 0.0038841464556753635 0.08291060477495193 0.4540385901927948
+0.252277672290802 0.0006888287607580423 0.0023755503352731466 0.13631565868854523 0.33973580598831177
+0.25314629077911377 0.015092619694769382 0.003655902110040188 0.12096431106328964 0.2765919864177704
+0.26823800802230835 0.0014397540362551808 0.0034493422135710716 0.09044262766838074 0.4050489366054535
+0.2354586273431778 0.0005859437515027821 0.002224172465503216 0.1429121345281601 0.2801370620727539
+0.26144829392433167 0.0007823980413377285 0.0018900708528235555 0.06492406129837036 0.3742234706878662
+0.25707995891571045 0.0010932039003819227 0.002950547728687525 0.07540048658847809 0.39206743240356445
+0.26983168721199036 0.0008975307573564351 0.005021612159907818 0.08958690613508224 0.38905787467956543
+0.2662755250930786 0.0009161422494798899 0.0034523080103099346 0.0965813621878624 0.3417161703109741
+0.2798925042152405 0.0006925945635885 0.0027014673687517643 0.09141513705253601 0.40865060687065125
+0.24723178148269653 0.00036747593549080193 0.0041498527862131596 0.05730481818318367 0.3097732365131378
+0.2519854009151459 0.0017584029119461775 0.002146892249584198 0.09622428566217422 0.3795848488807678
+0.27064743638038635 0.003582284552976489 0.001588460523635149 0.11878237873315811 0.3148922026157379
+0.26200729608535767 0.0021072214003652334 0.0020505613647401333 0.06777144968509674 0.39024418592453003
+0.23458638787269592 0.00046394698438234627 0.0021758703514933586 0.19650021195411682 0.3445311188697815
+0.22697091102600098 0.001085671246983111 0.0024407310411334038 0.16040559113025665 0.28546231985092163
+0.23313206434249878 0.001813498791307211 0.0021394742652773857 0.11845528334379196 0.31010597944259644
+0.25580430030822754 0.0007738622953183949 0.0018288299906998873 0.10857376456260681 0.34906837344169617
+0.22745276987552643 0.0005903386045247316 0.0027780577074736357 0.06453080475330353 0.34927836060523987
+0.2339339554309845 0.001327533507719636 0.005296588875353336 0.06595946103334427 0.3326915204524994
+0.2445366382598877 0.0007582380203530192 0.005095870234072208 0.08212351053953171 0.33191734552383423
+0.2640429437160492 0.0011491053737699986 0.0023692739196121693 0.14768901467323303 0.30581268668174744
+0.2461032271385193 0.002676230389624834 0.0019201255636289716 0.16905616223812103 0.37182021141052246
+0.2631750702857971 0.0015889324713498354 0.002572258235886693 0.0610433928668499 0.4230974316596985
+0.28127968311309814 0.0005761465872637928 0.0030943681485950947 0.12381517887115479 0.38835111260414124
+0.2978358864784241 0.0007997563807293773 0.0026805391535162926 0.14820139110088348 0.4333884119987488
+0.26048803329467773 0.0008660501916892827 0.0022658579982817173 0.1516004502773285 0.307062566280365
+0.2608875632286072 0.0006605183007195592 0.0032881894148886204 0.14368407428264618 0.3874509334564209
+0.26243796944618225 0.001230366062372923 0.0019738187547773123 0.09285196661949158 0.4154258370399475
+0.22857439517974854 0.001778124482370913 0.00219124392606318 0.1590462476015091 0.3193274438381195
+0.2596502900123596 0.000683382386341691 0.002069650450721383 0.08384343236684799 0.39350074529647827
+0.22911779582500458 0.0008592168451286852 0.0019281243439763784 0.10439813882112503 0.2635991871356964
+0.2830502390861511 0.0015926657943055034 0.002042103325948119 0.08275013417005539 0.3578382730484009
+0.2495991587638855 0.0007913041044957936 0.002404038794338703 0.07068926841020584 0.3837985098361969
+0.25591421127319336 0.0010733111994341016 0.006779937539249659 0.1104516014456749 0.41593414545059204
+0.22929951548576355 0.0009789735777303576 0.008421163074672222 0.08274156600236893 0.24454692006111145
+0.24018973112106323 0.0011064716381952167 0.0028928702231496572 0.06531158834695816 0.2879314422607422
+0.27894163131713867 0.0025164131075143814 0.0032262871973216534 0.18307720124721527 0.413867324590683
+0.23695191740989685 0.0009415476233698428 0.002801113296300173 0.0689164474606514 0.33581405878067017
+0.2718910276889801 0.0008446176070719957 0.002695272443816066 0.08668755739927292 0.46179986000061035
+0.2456529587507248 0.0021672574803233147 0.0023301427718251944 0.08233702182769775 0.27072545886039734
+0.2576579451560974 0.0008083695429377258 0.002409962471574545 0.15922653675079346 0.3508007526397705
+0.27435535192489624 0.0003418608976062387 0.005249626934528351 0.13482822477817535 0.40400055050849915
+0.2682843506336212 0.0005638344446197152 0.002339219441637397 0.10693421214818954 0.41891878843307495
+0.2586421072483063 0.0015869652852416039 0.003309749299660325 0.20382753014564514 0.3562844693660736
+0.2552856206893921 0.0007082725642248988 0.004113164730370045 0.16167676448822021 0.27908870577812195
+0.2558686137199402 0.0013672293862327933 0.0020305085927248 0.07219593226909637 0.2850166857242584
+0.23020735383033752 0.0025279491674154997 0.0028668770100921392 0.10762080550193787 0.36838042736053467
+0.25155073404312134 0.0005067289457656443 0.0022292027715593576 0.06810959428548813 0.3841105103492737
+0.2813200354576111 0.001708409865386784 0.009566727094352245 0.0873197689652443 0.3854186236858368
+0.26400744915008545 0.0011108136968687177 0.00847976841032505 0.15223656594753265 0.35406896471977234
+0.251231849193573 0.001519150915555656 0.0018210988491773605 0.11710473895072937 0.3607075810432434
+0.26411908864974976 0.003688761265948415 0.003846001112833619 0.1266470104455948 0.38263630867004395
+0.24927619099617004 0.0008106062305159867 0.002092375885695219 0.12268485128879547 0.36768072843551636
+0.2470380663871765 0.0018557851435616612 0.0041515701450407505 0.1748218834400177 0.2783536911010742
+0.2549975514411926 0.0026604465674608946 0.004436820279806852 0.09151032567024231 0.345041960477829
+0.2601577341556549 0.0007087431149557233 0.0020862645469605923 0.0975247174501419 0.3806461989879608
+0.2806342542171478 0.0006940121529623866 0.004167230799794197 0.06354505568742752 0.44428151845932007
+0.23238396644592285 0.0007314557442441583 0.0035944480914622545 0.08864574134349823 0.3034357726573944
+0.2636112868785858 0.0005037341616116464 0.0033676859457045794 0.14151594042778015 0.30643394589424133
+0.2577263116836548 0.0005899049574509263 0.0024629621766507626 0.16949713230133057 0.3810041844844818
+0.2652035653591156 0.0008737848256714642 0.0020219602156430483 0.09716146439313889 0.40882793068885803
+0.23925727605819702 0.0011560323182493448 0.0019676396623253822 0.07617314159870148 0.3279682993888855
+0.2423580288887024 0.0008180414442904294 0.0021201197523623705 0.07266039401292801 0.3373015820980072
+0.23311178386211395 0.0005411689635366201 0.002539250999689102 0.09588123857975006 0.2745036482810974
+0.29175227880477905 0.0010813939152285457 0.0021741611417382956 0.15686935186386108 0.5213536620140076
+0.26186561584472656 0.0014727265806868672 0.0022988831624388695 0.08950015157461166 0.37744179368019104
+0.2573225200176239 0.0011226740898564458 0.0026866276748478413 0.09357468038797379 0.3388829827308655
+0.2577534019947052 0.0005236019496805966 0.0015876309480518103 0.06999415159225464 0.3290323317050934
+0.23867195844650269 0.0005438672378659248 0.0013549391878768802 0.09541326016187668 0.36573976278305054
+0.24226614832878113 0.001201254315674305 0.0014481217367574573 0.09143760055303574 0.33819690346717834
+0.2644768953323364 0.0012155306758359075 0.0018527840729802847 0.08386685699224472 0.417983740568161
+0.2575615644454956 0.00050876330351457 0.002153696957975626 0.09832639992237091 0.3713507056236267
+0.25880563259124756 0.0006487731589004397 0.0021265624091029167 0.08066421747207642 0.340080589056015
+0.23415407538414001 0.0012533151311799884 0.0020284124184399843 0.06353232264518738 0.3019426167011261
+0.23551344871520996 0.0013522340450435877 0.002122987760230899 0.07678966224193573 0.39242544770240784
+0.23146861791610718 0.0004916353500448167 0.0027338045183569193 0.09129855036735535 0.343732625246048
+0.23738151788711548 0.0008043382549658418 0.0032568827737122774 0.07436873018741608 0.3211146295070648
+0.2263624370098114 0.0005882251425646245 0.0027293588500469923 0.14190673828125 0.27006804943084717
+0.24355842173099518 0.000563736364711076 0.0019981462974101305 0.0697762668132782 0.3545624911785126
+0.26653867959976196 0.0007426859810948372 0.0021976539865136147 0.12878194451332092 0.3729102611541748
+0.23319341242313385 0.000410756329074502 0.0023079533129930496 0.09128263592720032 0.36852318048477173
+0.2612132430076599 0.0005237004952505231 0.003018875140696764 0.17387031018733978 0.4285533130168915
+0.23280969262123108 0.0005283004138618708 0.0025083008222281933 0.12419073283672333 0.3047053813934326
+0.23248881101608276 0.0004035158781334758 0.003058214206248522 0.08545581251382828 0.3225332498550415
+0.26375913619995117 0.00037518132012337446 0.002572461264207959 0.07794540375471115 0.3793802857398987
+0.25255149602890015 0.0015905610052868724 0.00194833567366004 0.09419720619916916 0.3508002758026123
+0.2211226224899292 0.0015659135533496737 0.00320771848782897 0.11637044697999954 0.2821844518184662
+0.24729490280151367 0.0004316534614190459 0.0024027705658227205 0.13373634219169617 0.3671948313713074
+0.2433411329984665 0.0013892034767195582 0.002572859637439251 0.11530598253011703 0.3323274850845337
+0.24737361073493958 0.0015197695465758443 0.0021424300502985716 0.0787278488278389 0.38950851559638977
+0.2204342931509018 0.0005545276217162609 0.0037469398230314255 0.070783331990242 0.3465750217437744
+0.23031702637672424 0.0007439855253323913 0.005221856292337179 0.14737944304943085 0.2807866930961609
+0.24193188548088074 0.0005037745577283204 0.003012236673384905 0.0634642094373703 0.33248579502105713
+0.22893264889717102 0.0006315347854979336 0.0024473252706229687 0.0511055514216423 0.3468736708164215
+0.2171148657798767 0.0005156947881914675 0.002615673467516899 0.07740164548158646 0.3534966707229614
+0.2616790533065796 0.0005372435552999377 0.002863917965441942 0.12442735582590103 0.3857441842556
+0.2191113531589508 0.00058915582485497 0.0025273344945162535 0.12401910871267319 0.2908116579055786
+0.2500133514404297 0.0003721879329532385 0.002708448562771082 0.06989646703004837 0.38819217681884766
+0.25556331872940063 0.0003953857230953872 0.0029011177830398083 0.11417273432016373 0.3186257779598236
+0.22674067318439484 0.0008359176572412252 0.001691698795184493 0.07486901432275772 0.3212863802909851
+0.2267971932888031 0.0015271743759512901 0.0021787795703858137 0.07891334593296051 0.3527475893497467
+0.2494223415851593 0.0010359697043895721 0.002684373175725341 0.09388724714517593 0.39409899711608887
+0.23158171772956848 0.00048408203292638063 0.002648033434525132 0.0710257887840271 0.33496931195259094
+0.22990183532238007 0.0019972887821495533 0.0018280644435435534 0.13560689985752106 0.2427833378314972
+0.22563481330871582 0.0021466966718435287 0.0016839583404362202 0.0579272024333477 0.3180941045284271
+0.23974990844726562 0.000508370460011065 0.0025055943988263607 0.10735409706830978 0.3940349817276001
+0.22959834337234497 0.0011262426851317286 0.005589182022958994 0.09626860916614532 0.3728175759315491
+0.24661579728126526 0.0009200849453918636 0.004311538301408291 0.0633319690823555 0.3123551607131958
+0.23318904638290405 0.000577587285079062 0.0016523012891411781 0.09710229933261871 0.2634734511375427
+0.22211289405822754 0.001569718704558909 0.002423375379294157 0.07215572893619537 0.2762555778026581
+0.218587264418602 0.0008214263943955302 0.0017637013224884868 0.05565414950251579 0.3772614896297455
+0.230634868144989 0.0008240194874815643 0.0037440871819853783 0.09888053685426712 0.41420841217041016
+0.22719168663024902 0.0011347760446369648 0.003460341366007924 0.1425018012523651 0.2645772099494934
+0.25117915868759155 0.00040135617018677294 0.003584652440622449 0.08010752499103546 0.33639833331108093
+0.23827281594276428 0.0005738253821618855 0.006449609529227018 0.0851316973567009 0.3203079104423523
+0.24803929030895233 0.0005347955739125609 0.003408205695450306 0.05842670425772667 0.34868112206459045
+0.23630720376968384 0.00309454626403749 0.002219713293015957 0.15200622379779816 0.34966692328453064
+0.21897488832473755 0.0025166289415210485 0.004708188120275736 0.10104077309370041 0.3600674867630005
+0.20577271282672882 0.0004980387166142464 0.002697911113500595 0.10865834355354309 0.2481500506401062
+0.24268512427806854 0.0034292510244995356 0.003066265257075429 0.07562011480331421 0.3083934485912323
+0.23840901255607605 0.0024137143045663834 0.002288186689838767 0.04235504940152168 0.34304842352867126
+0.21930737793445587 0.00048335272003896534 0.004744767211377621 0.06988684833049774 0.31693315505981445
+0.2184416502714157 0.0015661247307434678 0.008885333314538002 0.09427535533905029 0.34580543637275696
+0.24538446962833405 0.0005837166681885719 0.005367028061300516 0.06620924919843674 0.3792771100997925
+0.22851136326789856 0.001483757863752544 0.0025291882921010256 0.08841638267040253 0.3147984743118286
+0.23900839686393738 0.00206750282086432 0.003505513072013855 0.06751826405525208 0.3173956573009491
+0.25092658400535583 0.0005137391272000968 0.002732319524511695 0.07731270045042038 0.2849825918674469
+0.23786598443984985 0.001762588508427143 0.0023300594184547663 0.1046234592795372 0.32478970289230347
+0.19451166689395905 0.0023866656702011824 0.0028444926720112562 0.062358640134334564 0.3291065990924835
+0.24501074850559235 0.0008793164161033928 0.0021016087848693132 0.07322868704795837 0.4282229244709015
+0.2517794370651245 0.0006077784346416593 0.0023800123017281294 0.09050600975751877 0.3734007775783539
+0.2564147710800171 0.0020922834519296885 0.0031477059237658978 0.09157025068998337 0.17989684641361237
+0.2322942614555359 0.0008552266517654061 0.0026590724010020494 0.07153813540935516 0.30367404222488403
+0.24065753817558289 0.0005968183977529407 0.0017155379755422473 0.05093948170542717 0.4362390637397766
+0.26033324003219604 0.00199478049762547 0.006145987659692764 0.08140617609024048 0.46786415576934814
+0.20813730359077454 0.0008474843925796449 0.005968482233583927 0.1329604685306549 0.2643436789512634
+0.24156662821769714 0.001784417312592268 0.0024380646646022797 0.07545896619558334 0.22388331592082977
+0.2207660675048828 0.0040503800846636295 0.005495704244822264 0.07229414582252502 0.32197320461273193
+0.25520408153533936 0.00047998130321502686 0.0021106365602463484 0.06091078370809555 0.46852487325668335
+0.21995283663272858 0.0034565969835966825 0.006692843046039343 0.08979448676109314 0.3279462456703186
+0.20635583996772766 0.0033296747133135796 0.01685202494263649 0.13378101587295532 0.27401018142700195
+0.21949976682662964 0.0006060304585844278 0.0023793599102646112 0.0841304287314415 0.28956466913223267
+0.24159979820251465 0.0021323678083717823 0.005880401469767094 0.1650771051645279 0.29767510294914246
+0.23832035064697266 0.0005149041535332799 0.002008510287851095 0.058263782411813736 0.3722272217273712
+0.23191651701927185 0.002031083917245269 0.006646658759564161 0.09566590189933777 0.38467028737068176
+0.22049462795257568 0.0019604831468313932 0.009562098421156406 0.13510645925998688 0.2844628691673279
+0.223982572555542 0.000934379524551332 0.0022367527708411217 0.08145797997713089 0.3068484365940094
+0.18957027792930603 0.00440181465819478 0.004047345370054245 0.04039612412452698 0.20985932648181915
+0.19502907991409302 0.0016876156441867352 0.00209114165045321 0.04781252145767212 0.24914544820785522
+0.22839444875717163 0.0013820661697536707 0.004698650445789099 0.06835579127073288 0.35689282417297363
+0.23395249247550964 0.003260406432673335 0.007267057429999113 0.14716392755508423 0.29518136382102966
+0.2159835547208786 0.0007511304575018585 0.002387553919106722 0.08326732367277145 0.2575536072254181
+0.26589035987854004 0.001697456231340766 0.004972507245838642 0.09552288800477982 0.3991175889968872
+0.22763895988464355 0.0014971720520406961 0.0053520845249295235 0.08271332085132599 0.36876243352890015
+0.23821283876895905 0.000716963957529515 0.002041959436610341 0.09264995902776718 0.3418855369091034
+0.20788952708244324 0.0027431969065219164 0.0023961029946804047 0.0922415554523468 0.25949978828430176
+0.22194001078605652 0.00171008484903723 0.0017176290275529027 0.09057828038930893 0.3017001152038574
+0.21026593446731567 0.00041786092333495617 0.002948411973193288 0.13568651676177979 0.2676490545272827
+0.2699544131755829 0.000520921079441905 0.0021022311411798 0.0697096660733223 0.4577009975910187
+0.19424551725387573 0.0005878239753656089 0.0030459801200777292 0.07297523319721222 0.25672173500061035
+0.2546575665473938 0.0006697116186842322 0.004453422501683235 0.09285978227853775 0.32845285534858704
+0.21658840775489807 0.0043630460277199745 0.0024572003167122602 0.07003895193338394 0.28827622532844543
+0.23335927724838257 0.002405735896900296 0.0016664456343278289 0.08887358754873276 0.38700830936431885
+0.2096114456653595 0.0009360542171634734 0.002043592743575573 0.06304081529378891 0.34178903698921204
+0.2127285897731781 0.0008731084526516497 0.0035344904754310846 0.0970228910446167 0.2906738817691803
+0.22003433108329773 0.001200141734443605 0.0027965272311121225 0.08992695063352585 0.19830304384231567
+0.2219245433807373 0.0003933207190129906 0.0019327531335875392 0.060600873082876205 0.3278566002845764
+0.21543312072753906 0.0005579421413131058 0.0024659952614456415 0.05322789400815964 0.36389297246932983
+0.21984288096427917 0.0006693715113215148 0.0016969835851341486 0.08159011602401733 0.33750802278518677
+0.205206960439682 0.0011957744136452675 0.0027569246012717485 0.09650497138500214 0.23389498889446259
+0.2501460909843445 0.0004842414637096226 0.0019760928116738796 0.08570502698421478 0.2884354889392853
+0.23579832911491394 0.0011415143962949514 0.002867815550416708 0.06279296427965164 0.4050009548664093
+0.2286199927330017 0.0006465784390456975 0.002015927108004689 0.06908873468637466 0.3929128646850586
+0.2075832188129425 0.0010357409482821822 0.002709583379328251 0.09202215820550919 0.23867391049861908
+0.23509590327739716 0.0013859387254342437 0.0034398019779473543 0.09174078702926636 0.35113289952278137
+0.1991022229194641 0.0004635945660993457 0.0018718740902841091 0.06491687148809433 0.2768210470676422
+0.19632107019424438 0.00132784154266119 0.0019358198624104261 0.05042165145277977 0.309699684381485
+0.20982645452022552 0.0010336017003282905 0.0015751307364553213 0.07241754978895187 0.3061469495296478
+0.23533552885055542 0.0005427406285889447 0.002291979268193245 0.19091211259365082 0.3143150210380554
+0.21247410774230957 0.0007523114909417927 0.00268893176689744 0.07570141553878784 0.24416504800319672
+0.21545633673667908 0.0004026549868285656 0.001396616455167532 0.055971503257751465 0.32586899399757385
+0.24101081490516663 0.0005416497006081045 0.001653736107982695 0.049421072006225586 0.4297381341457367
+0.21543484926223755 0.00041718085412867367 0.0016592440661042929 0.12290230393409729 0.29673969745635986
+0.24891769886016846 0.0004937081248499453 0.002214213600382209 0.09414118528366089 0.35925447940826416
+0.21481788158416748 0.0004289864154998213 0.002028330462053418 0.08470100909471512 0.2961166799068451
+0.1838744878768921 0.0006715465569868684 0.0023826456163078547 0.08003120869398117 0.2556077539920807
+0.199642151594162 0.00039334726170636714 0.0021423515863716602 0.05309338867664337 0.32119810581207275
+0.19564031064510345 0.0010269810445606709 0.0017023551044985652 0.07310651987791061 0.28221243619918823
+0.22257089614868164 0.001567829749546945 0.0028991070576012135 0.08465972542762756 0.3232121467590332
+0.21675406396389008 0.0004367668298073113 0.001965954201295972 0.07033798843622208 0.26892808079719543
+0.19726499915122986 0.0012032256927341223 0.0018229822162538767 0.04871704801917076 0.30563244223594666
+0.21886375546455383 0.0013633986236527562 0.0017585244495421648 0.05515677109360695 0.3369535207748413
+0.23273499310016632 0.00038859559572301805 0.0017852376913651824 0.055463556200265884 0.3235546946525574
+0.21020852029323578 0.0007923449156805873 0.003424813039600849 0.08740685880184174 0.2553125023841858
+0.21804124116897583 0.00046023211325518787 0.001762455445714295 0.06888851523399353 0.24708916246891022
+0.21216422319412231 0.0004481537325773388 0.0016297743422910571 0.04744968190789223 0.33724695444107056
+0.2207093983888626 0.00042249387479387224 0.0015260478248819709 0.059861794114112854 0.38940200209617615
+0.2168968766927719 0.0005349242710508406 0.0028186298441141844 0.11907509714365005 0.34036436676979065
+0.20514747500419617 0.0006179200136102736 0.0023389605339616537 0.11710487306118011 0.20591871440410614
+0.22092221677303314 0.0005709226243197918 0.0021072044037282467 0.09284524619579315 0.24233029782772064
+0.22403192520141602 0.0007367050857283175 0.0031720937695354223 0.05305255576968193 0.320139080286026
+0.2310800701379776 0.000655416224617511 0.0012533366680145264 0.049065638333559036 0.38748204708099365
+0.20366686582565308 0.002971487818285823 0.00713342847302556 0.11135696619749069 0.31001028418540955
+0.22094625234603882 0.0009648131090216339 0.00751853734254837 0.11577039211988449 0.23337361216545105
+0.22575417160987854 0.002617690246552229 0.0030319769866764545 0.05046594515442848 0.29757770895957947
+0.23684349656105042 0.005048014223575592 0.006349519826471806 0.0444805845618248 0.3556676506996155
+0.2130947709083557 0.0005384165560826659 0.002131665823981166 0.048603251576423645 0.3823287785053253
+0.21905620396137238 0.004050714895129204 0.004676970653235912 0.10024300217628479 0.31575512886047363
+0.21022269129753113 0.004484749399125576 0.0038948997389525175 0.08835220336914062 0.19811812043190002
+0.20390981435775757 0.0004621602420229465 0.0023560398258268833 0.04445436969399452 0.29575401544570923
+0.24602091312408447 0.0017064822604879737 0.004575085360556841 0.06756264716386795 0.45000699162483215
+0.22708876430988312 0.0007374509004876018 0.001758641330525279 0.09718767553567886 0.37461626529693604
+0.227990984916687 0.0007811608375050128 0.004566363524645567 0.166792631149292 0.32762256264686584
+0.23396623134613037 0.000508168013766408 0.0028034280985593796 0.10001040995121002 0.24478314816951752
+0.2426021844148636 0.0007804064080119133 0.0018204160733148456 0.07152242958545685 0.32909828424453735
+0.1868799775838852 0.000630362774245441 0.0022306595928967 0.0577792227268219 0.2528305947780609
+0.21569699048995972 0.0010100000072270632 0.001691656420007348 0.048905353993177414 0.33711904287338257
+0.21981537342071533 0.0027524782344698906 0.008924763649702072 0.11398536711931229 0.35754063725471497
+0.22528591752052307 0.0005009844317100942 0.006849868223071098 0.10745760053396225 0.34994927048683167
+0.21012189984321594 0.00442857388406992 0.001997484127059579 0.10582578182220459 0.27749714255332947
+0.22580672800540924 0.007431477773934603 0.00499830674380064 0.07660531252622604 0.24728432297706604
+0.19887994229793549 0.0010738441487774253 0.0020135482773184776 0.05154827609658241 0.2734725773334503
+0.2172602117061615 0.0047253514640033245 0.004241548478603363 0.06422057747840881 0.32313719391822815
+0.22306938469409943 0.008223090320825577 0.008852888830006123 0.08787917345762253 0.3769964575767517
+0.20328280329704285 0.0011094107758253813 0.0035885637626051903 0.08370691537857056 0.28298419713974
+0.22104288637638092 0.004256453365087509 0.007321384735405445 0.06138908490538597 0.2297276109457016
+0.20802649855613708 0.00675062695518136 0.011081959120929241 0.08189229667186737 0.26442858576774597
+0.2218683362007141 0.0005382297676987946 0.002833469770848751 0.05574001744389534 0.32774102687835693
+0.21248692274093628 0.005784912034869194 0.008291651494801044 0.09527960419654846 0.34034696221351624
+0.24728043377399445 0.005877931602299213 0.007019898388534784 0.1501769721508026 0.3951035141944885
+0.21185846626758575 0.00043452231329865754 0.0025612080935388803 0.10121120512485504 0.26674896478652954
+0.21880371868610382 0.0031800484284758568 0.007202543318271637 0.07042930275201797 0.26774218678474426
+0.20917564630508423 0.0017787751276046038 0.0036315806210041046 0.04309045150876045 0.2993251383304596
+0.2064039409160614 0.0010088460985571146 0.0035151129122823477 0.10103122889995575 0.3245733976364136
+0.24212954938411713 0.002736428752541542 0.006694257725030184 0.07518018037080765 0.4043455421924591
+0.179695725440979 0.0009135175496339798 0.0020992334466427565 0.06288542598485947 0.2334173321723938
+0.21877530217170715 0.0014025981072336435 0.0019596037454903126 0.05976659432053566 0.3027632236480713
+0.22665125131607056 0.0017387168481945992 0.0018203852232545614 0.057509250938892365 0.3240306079387665
+0.23216016590595245 0.0006448225467465818 0.0022652167826890945 0.06023111194372177 0.34532079100608826
+0.22601890563964844 0.0011675601126626134 0.0076416730880737305 0.12813982367515564 0.32922500371932983
+0.18544279038906097 0.0004979726509191096 0.004142585676163435 0.07859641313552856 0.2650529742240906
+0.2246406078338623 0.0018507672939449549 0.0019034907454624772 0.05174707621335983 0.33613404631614685
+0.19723078608512878 0.0018608580576255918 0.002986195730045438 0.07063812017440796 0.23868511617183685
+0.2079811841249466 0.0011879362864419818 0.001799862366169691 0.09633245319128036 0.22025389969348907
+0.2043270468711853 0.00230266316793859 0.0037429253570735455 0.06867115944623947 0.3013949692249298
+0.21433287858963013 0.0019113834714516997 0.0040948959067463875 0.06145511567592621 0.36958926916122437
+0.19797693192958832 0.000396564049879089 0.00295889750123024 0.09498453885316849 0.2912811040878296
+0.19205904006958008 0.0016027726233005524 0.003769693663343787 0.09975238889455795 0.2699977457523346
+0.20004533231258392 0.0013842078624293208 0.004137351643294096 0.07989033311605453 0.20295016467571259
+0.18913188576698303 0.0004386816581245512 0.002238370943814516 0.05938571318984032 0.24820803105831146
+0.22173729538917542 0.0019867343362420797 0.002514017280191183 0.05403658002614975 0.3335840106010437
+0.1833707094192505 0.0025828240904957056 0.0030473475344479084 0.0955783948302269 0.2957286536693573
+0.22416529059410095 0.0006454195827245712 0.002248261123895645 0.06345130503177643 0.3675069510936737
+0.20774543285369873 0.0011048369342461228 0.002112717367708683 0.0766640156507492 0.2545175552368164
+0.2004261165857315 0.0026957772206515074 0.002231963910162449 0.0619775764644146 0.2332981377840042
+0.22336135804653168 0.0013170641614124179 0.0015132864937186241 0.07900169491767883 0.3321230411529541
+0.22220978140830994 0.000608286471106112 0.0031038036104291677 0.060098618268966675 0.40061166882514954
+0.1993405520915985 0.0016004317440092564 0.0053962161764502525 0.08208193629980087 0.29479116201400757
+0.20203281939029694 0.0005893819034099579 0.002153361914679408 0.07615619152784348 0.2010779082775116
+0.1925641894340515 0.0008982146973721683 0.002909241011366248 0.04875947907567024 0.15720829367637634
+0.25694525241851807 0.0008942834101617336 0.003474028315395117 0.05433591827750206 0.45740699768066406
+0.1902187615633011 0.0007391502731479704 0.0022243994753807783 0.08634019643068314 0.33346131443977356
+0.22076532244682312 0.001975528197363019 0.005198531784117222 0.11480817198753357 0.29353269934654236
+0.22614584863185883 0.000520376896020025 0.0017190836369991302 0.11101032793521881 0.19206589460372925
+0.19554755091667175 0.0017800641944631934 0.005759973544627428 0.050630588084459305 0.25758764147758484
+0.22693736851215363 0.0006971246330067515 0.0035396001767367125 0.03987140581011772 0.40307289361953735
+0.21833863854408264 0.0023800276685506105 0.0043370905332267284 0.06609099358320236 0.32194095849990845
+0.2192041575908661 0.0038639495614916086 0.012478657998144627 0.1280595362186432 0.2958652079105377
+0.196834534406662 0.000503536663018167 0.0026459763757884502 0.08334704488515854 0.2428147941827774
+0.19076670706272125 0.005916335619986057 0.006852713413536549 0.06807626038789749 0.23830468952655792
+0.2077309787273407 0.002662482438609004 0.004492562264204025 0.061083875596523285 0.24383306503295898
+0.17479155957698822 0.002121684607118368 0.005417428445070982 0.06552664190530777 0.2541830539703369
+0.25512880086898804 0.007112216670066118 0.019859550520777702 0.1020432561635971 0.4572162628173828
+0.19282980263233185 0.0010232445783913136 0.008964277803897858 0.09710857272148132 0.31894874572753906
+0.2246493697166443 0.005329366307705641 0.004554344806820154 0.07721804082393646 0.2867550551891327
+0.21557503938674927 0.009528137743473053 0.012767881155014038 0.0627431720495224 0.16326949000358582
+0.20763547718524933 0.0006146997911855578 0.004600762855261564 0.049313053488731384 0.323306143283844
+0.2167024463415146 0.010159925557672977 0.003969253972172737 0.06501264125108719 0.39457792043685913
+0.2022755742073059 0.014526675455272198 0.010504398494958878 0.11559395492076874 0.3259437084197998
+0.22387203574180603 0.001431892509572208 0.0019820122979581356 0.07992663234472275 0.29900267720222473
+0.25975459814071655 0.006672917399555445 0.010188763029873371 0.11813525855541229 0.2454766482114792
+0.21718406677246094 0.008167851716279984 0.010608622804284096 0.05845221132040024 0.3028169870376587
+0.21877625584602356 0.0005369066493585706 0.0024067515041679144 0.08942501246929169 0.40238094329833984
+0.22039464116096497 0.008776180446147919 0.022498739883303642 0.13290049135684967 0.29262933135032654
+0.23031464219093323 0.004011481534689665 0.01769816316664219 0.11437887698411942 0.32508373260498047
+0.17709779739379883 0.0029742533806711435 0.0021628765389323235 0.054699428379535675 0.25691521167755127
+0.1941264420747757 0.012049611657857895 0.00914386473596096 0.05737749859690666 0.2369348704814911
+0.20094376802444458 0.0032681438606232405 0.0035964325070381165 0.09006217122077942 0.2675212323665619
+0.2184653878211975 0.0034795962274074554 0.006365037988871336 0.09176220744848251 0.2556456923484802
+0.2108231484889984 0.00849709752947092 0.011318338103592396 0.08892860263586044 0.28648513555526733
+0.2066250890493393 0.0018432113574817777 0.00288993027061224 0.06011989340186119 0.33809423446655273
+0.23900817334651947 0.001992492936551571 0.004593128338456154 0.04756201058626175 0.3827938139438629
+0.19327911734580994 0.003976329229772091 0.0066298553720116615 0.055622827261686325 0.24768328666687012
+0.19213661551475525 0.0009530197712592781 0.003337726229801774 0.08748012036085129 0.23898296058177948
+0.18754182755947113 0.0014542876742780209 0.0018132200930267572 0.05668826401233673 0.23015014827251434
+0.18328829109668732 0.003959920722991228 0.0014008604921400547 0.05029560625553131 0.29085272550582886
+0.20573151111602783 0.0031731596682220697 0.00154836755245924 0.10655825585126877 0.3325992226600647
+0.16759327054023743 0.0006387080647982657 0.002295947866514325 0.0636124461889267 0.22009199857711792
+0.17334935069084167 0.0016206795116886497 0.002039699349552393 0.07832205295562744 0.2227916121482849
+0.20338231325149536 0.004422422032803297 0.0020451839081943035 0.07607685029506683 0.28753232955932617
+0.2019909918308258 0.0029454620089381933 0.0015933114336803555 0.058992158621549606 0.32946500182151794
+0.19000419974327087 0.0004896630416624248 0.001872647670097649 0.10210441797971725 0.23301304876804352
+0.19814705848693848 0.001459667575545609 0.002052144380286336 0.06851522624492645 0.2075771540403366
+0.18823307752609253 0.0017175673274323344 0.0013847250957041979 0.05140714719891548 0.17557506263256073
+0.2030714750289917 0.0006807460449635983 0.0010977936908602715 0.07575923949480057 0.3330537676811218
+0.18706347048282623 0.0004176959628239274 0.0016624487470835447 0.04704708606004715 0.3099677264690399
+0.18413996696472168 0.0006041937740519643 0.0029805416706949472 0.09041017293930054 0.282552570104599
+0.18260744214057922 0.001151994918473065 0.002499211812391877 0.10032246261835098 0.2528739869594574
+0.25038856267929077 0.0013527352130040526 0.0034255231730639935 0.055840615183115005 0.24043096601963043
+0.22849337756633759 0.000429160165367648 0.0027873478829860687 0.0476110577583313 0.31822940707206726
+0.19691911339759827 0.0025068146642297506 0.0017551443306729198 0.059214480221271515 0.33573824167251587
+0.18941780924797058 0.0051008774898946285 0.005266855005174875 0.07153965532779694 0.3286892771720886
+0.22120019793510437 0.0013479484478011727 0.005304969847202301 0.1352546513080597 0.31691062450408936
+0.19521768391132355 0.002310070674866438 0.0031739133410155773 0.09385818988084793 0.23081310093402863
+0.19705595076084137 0.006911729462444782 0.008065471425652504 0.06619853526353836 0.22197818756103516
+0.19833225011825562 0.0014810997527092695 0.0037065010983496904 0.09457598626613617 0.26798659563064575
+0.20051254332065582 0.0033937592525035143 0.0025580127257853746 0.058402612805366516 0.2889859080314636
+0.19526103138923645 0.009674862958490849 0.010857516899704933 0.065342016518116 0.25085464119911194
+0.20918473601341248 0.0023413519375026226 0.0065863775089383125 0.07312174886465073 0.302970826625824
+0.19822125136852264 0.00308018340729177 0.001853420864790678 0.06672240048646927 0.259408563375473
+0.23907652497291565 0.010900532826781273 0.003318325150758028 0.04112887382507324 0.3947494626045227
+0.1877903938293457 0.0049280570819973946 0.002365435240790248 0.07589466124773026 0.24966798722743988
+0.21661242842674255 0.0005425841663964093 0.0038417535834014416 0.10776980966329575 0.2895177900791168
+0.17939339578151703 0.002311090938746929 0.002901904983446002 0.06480784714221954 0.23456576466560364
+0.1853112131357193 0.002047725021839142 0.0011891707545146346 0.047235406935214996 0.23439134657382965
+0.17449061572551727 0.000879111175891012 0.002739881630986929 0.04253622889518738 0.25149112939834595
+0.2013224959373474 0.0007909424020908773 0.0022093479055911303 0.06123248115181923 0.2933046519756317
+0.18577012419700623 0.0004333110118750483 0.00239490601234138 0.08769621700048447 0.2628776431083679
+0.17650559544563293 0.00042364763794466853 0.0026278886944055557 0.13252651691436768 0.22199760377407074
+0.15996766090393066 0.0006095392745919526 0.002979528158903122 0.046079397201538086 0.23709933459758759
+0.22861380875110626 0.00032782970811240375 0.0014689437812194228 0.04350364953279495 0.3887456953525543
+0.20512281358242035 0.001822342281229794 0.0026565773878246546 0.06095400080084801 0.2505946457386017
+0.1913076639175415 0.0015607208479195833 0.004715930204838514 0.07075690478086472 0.2220247983932495
+0.21107885241508484 0.0004933580057695508 0.0020889551378786564 0.10341832786798477 0.3267001211643219
+0.20264866948127747 0.00289524020627141 0.0018149196403101087 0.04431721195578575 0.355938196182251
+0.20583081245422363 0.0024889837950468063 0.0018947827629745007 0.09147012233734131 0.3411584794521332
+0.21449968218803406 0.0005703803617507219 0.0030124548356980085 0.08249531686306 0.22441230714321136
+0.20870983600616455 0.0006467962521128356 0.0022864886559545994 0.06541036814451218 0.21525117754936218
+0.19606582820415497 0.0008176693809218705 0.0014240799937397242 0.04020976275205612 0.30558842420578003
+0.173350989818573 0.0009916970739141107 0.001726810005493462 0.042119454592466354 0.27431198954582214
+0.1819508969783783 0.0009755735518410802 0.0034405996557325125 0.10253402590751648 0.2844439446926117
+0.20293906331062317 0.00046786380698904395 0.003404967486858368 0.08979768306016922 0.3061332404613495
+0.21242289245128632 0.0017282834742218256 0.0034972515422850847 0.09338576346635818 0.24836379289627075
+0.19890053570270538 0.0021804682910442352 0.0055534797720611095 0.09326387196779251 0.24656236171722412
+0.20556268095970154 0.0003802514693234116 0.0026092403568327427 0.0511513315141201 0.32208144664764404
+0.24229347705841064 0.004622759763151407 0.0023947821464389563 0.057916633784770966 0.4169482886791229
+0.21514448523521423 0.006317009683698416 0.007043226156383753 0.06922519207000732 0.24561496078968048
+0.18918323516845703 0.0006724336999468505 0.0026665341574698687 0.07259715348482132 0.21064604818820953
+0.17527613043785095 0.0044716726988554 0.0032509679440408945 0.05287019908428192 0.24283838272094727
+0.20906969904899597 0.00629205210134387 0.003952388186007738 0.04444187507033348 0.4028310477733612
+0.22097055613994598 0.0007807666552253067 0.0019995674956589937 0.05873310565948486 0.3885928988456726
+0.19116216897964478 0.0026621134020388126 0.007156746461987495 0.1518244445323944 0.206899031996727
+0.21706023812294006 0.0018497150158509612 0.003122999332845211 0.09635867923498154 0.17688751220703125
+0.20326779782772064 0.0005487778689712286 0.0021228110417723656 0.05326457321643829 0.2662808895111084
+0.23262375593185425 0.0008220482268370688 0.0027295073959976435 0.055992987006902695 0.4305925965309143
+0.16580533981323242 0.0007512754527851939 0.0032199383713304996 0.0547901876270771 0.27501583099365234
+0.19265049695968628 0.0016391667304560542 0.010648016817867756 0.13109880685806274 0.23593491315841675
+0.19638395309448242 0.0008155275718308985 0.003710119053721428 0.10733382403850555 0.18873195350170135
+0.20209252834320068 0.004927534144371748 0.0069603826850652695 0.08910276740789413 0.2091565877199173
+0.1966785490512848 0.0023906847927719355 0.007333244197070599 0.04012328386306763 0.30128565430641174
+0.18268117308616638 0.002572418423369527 0.0017709090607240796 0.09788474440574646 0.30058690905570984
+0.20857329666614532 0.010759969241917133 0.012718413956463337 0.12946957349777222 0.33952733874320984
+0.19524142146110535 0.0035212738439440727 0.006021957378834486 0.11795622110366821 0.25324639678001404
+0.17182035744190216 0.006366886664181948 0.00465622590854764 0.05787637084722519 0.1942201405763626
+0.20252999663352966 0.00997820682823658 0.015748631209135056 0.052471138536930084 0.27001968026161194
+0.20522165298461914 0.0014608564088121057 0.0040843249298632145 0.04017597809433937 0.28778812289237976
+0.21508589386940002 0.005165749695152044 0.00517694978043437 0.1202438622713089 0.3152904808521271
+0.16606947779655457 0.009964997880160809 0.011425366625189781 0.1180199533700943 0.18807744979858398
+0.18333700299263 0.0009706969722174108 0.00168650783598423 0.08615124970674515 0.2653166353702545
+0.19402539730072021 0.0038971821777522564 0.005535148084163666 0.06423470377922058 0.227275088429451
+0.1927187442779541 0.004526787903159857 0.004195203073322773 0.03976970165967941 0.2601501941680908
+0.18847694993019104 0.0006399896228685975 0.003262113779783249 0.12630359828472137 0.2839818298816681
+0.18721672892570496 0.0022909112740308046 0.009649563580751419 0.11726220697164536 0.2838229835033417
+0.18256719410419464 0.0010847499361261725 0.00373904500156641 0.09235206246376038 0.25529277324676514
+0.15337501466274261 0.0009376727393828332 0.001868558581918478 0.061646778136491776 0.19165174663066864
+0.21061566472053528 0.0013921110657975078 0.0033712449949234724 0.07539259642362595 0.22243237495422363
+0.19016823172569275 0.0004799403250217438 0.0011874380288645625 0.06176908686757088 0.22552363574504852
+0.1868821680545807 0.0023289977107197046 0.004903663881123066 0.0788780152797699 0.2988845705986023
+0.1941179484128952 0.0015035076066851616 0.005478620063513517 0.08272808790206909 0.28218474984169006
+0.176103413105011 0.0009200883214361966 0.00205536768771708 0.09014077484607697 0.2351895272731781
+0.19644305109977722 0.003740342566743493 0.0048940470442175865 0.04809332638978958 0.2467358410358429
+0.19656318426132202 0.0014313174178823829 0.003756643505766988 0.09322237968444824 0.2003195732831955
+0.17728185653686523 0.0014893902698531747 0.001453037024475634 0.055401191115379333 0.2929423451423645
+0.20923325419425964 0.006123406812548637 0.00596284493803978 0.06958652287721634 0.3893287181854248
+0.17707693576812744 0.002480918075889349 0.004454122856259346 0.10907136648893356 0.23865580558776855
+0.1986202895641327 0.0011101223062723875 0.001993095502257347 0.07961084693670273 0.2050623744726181
+0.19169528782367706 0.005179793573915958 0.007290225010365248 0.04481034725904465 0.2765125334262848
+0.17808540165424347 0.0014960648259148002 0.0029381681233644485 0.061843231320381165 0.2548682689666748
+0.19465835392475128 0.002011362463235855 0.0025118673220276833 0.05838179215788841 0.2759605050086975
+0.17466506361961365 0.004342892672866583 0.005597845185548067 0.08405805379152298 0.25149187445640564
+0.17464125156402588 0.0012516265269368887 0.002116329735144973 0.06812676042318344 0.20472939312458038
+0.16491252183914185 0.001180840190500021 0.0021919053979218006 0.042305298149585724 0.2252264767885208
+0.18680784106254578 0.0024166093207895756 0.0026259711012244225 0.05490221083164215 0.2322641909122467
+0.16556215286254883 0.0007346639176830649 0.0022669481113553047 0.05104292929172516 0.2505364716053009
+0.15101337432861328 0.0010637118248268962 0.006495238747447729 0.07132977992296219 0.20919014513492584
+0.16347697377204895 0.0009579146862961352 0.00596799748018384 0.07043256610631943 0.23710723221302032
+0.167390376329422 0.0006863757153041661 0.002467653714120388 0.06254720687866211 0.21374210715293884
+0.1679707169532776 0.0016462343046441674 0.0028676383662968874 0.03834732621908188 0.2191035896539688
+0.18084168434143066 0.0006047971546649933 0.001803375082090497 0.03950273245573044 0.2296421378850937
+0.2049783319234848 0.0011878160294145346 0.001974563580006361 0.12880578637123108 0.2414359599351883
+0.17406070232391357 0.0020888473372906446 0.0036891712807118893 0.1080808937549591 0.24765627086162567
+0.1928522139787674 0.000579409534111619 0.0020374299492686987 0.0637841522693634 0.3002091944217682
+0.15813612937927246 0.0008838649955578148 0.003432223107665777 0.05956117808818817 0.19503876566886902
+0.16630306839942932 0.001234398572705686 0.005037613678723574 0.09559988975524902 0.19762349128723145
+0.20016393065452576 0.00034243270056322217 0.002103229984641075 0.06040268763899803 0.3047569990158081
+0.1721828579902649 0.002494164975360036 0.00206856825388968 0.06952478736639023 0.24833567440509796
+0.19407179951667786 0.002545727416872978 0.0024549381341785192 0.07082357257604599 0.2954796552658081
+0.20812419056892395 0.00040378107223659754 0.0011003789259120822 0.06266710162162781 0.2999178171157837
+0.17034773528575897 0.0013753515668213367 0.0020416302140802145 0.0538557767868042 0.25258970260620117
+0.18381226062774658 0.0015327278524637222 0.0012612269492819905 0.0968410074710846 0.26188015937805176
+0.19073137640953064 0.000527448020875454 0.0013030018890276551 0.09973373264074326 0.24291789531707764
+0.18666373193264008 0.0008076004451140761 0.00280898273922503 0.060418471693992615 0.2741214334964752
+0.1781979501247406 0.0008249682141467929 0.0029681501910090446 0.04613799229264259 0.26002037525177
+0.18165148794651031 0.0003817721735686064 0.0016786452615633607 0.07788878679275513 0.22609154880046844
+0.18026435375213623 0.0009293273906223476 0.0013107431586831808 0.051638178527355194 0.20893731713294983
+0.15498057007789612 0.0013156741624698043 0.0011400337098166347 0.047992244362831116 0.21892794966697693
+0.16310438513755798 0.000770623330026865 0.0016416787402704358 0.05622048303484917 0.19249477982521057
+0.1663040816783905 0.0004748367646243423 0.028170939534902573 0.05337999388575554 0.2785767912864685
+0.1837165653705597 0.0005780098144896328 0.0020676208660006523 0.09493926167488098 0.2456657886505127
+0.164372980594635 0.0005466979346238077 0.0018354617059230804 0.054894328117370605 0.19013604521751404
+0.18532997369766235 0.0003993531863670796 0.002190452069044113 0.04253266006708145 0.3133891522884369
+0.17628289759159088 0.000366643478628248 0.002454014727845788 0.05686517432332039 0.26245447993278503
+0.17319315671920776 0.0005000433302484453 0.002022570464760065 0.05996658653020859 0.24696891009807587
+0.15892773866653442 0.0006641576183028519 0.0022527799010276794 0.06642874330282211 0.1929662674665451
+0.18710573017597198 0.0005196607671678066 0.001871318556368351 0.05973614379763603 0.282416969537735
+0.17279237508773804 0.00048289966071024537 0.001752549083903432 0.07494629174470901 0.22464558482170105
+0.16016799211502075 0.0004500410286709666 0.0017898066435009241 0.05250254645943642 0.23433028161525726
+0.1761864423751831 0.00037701104884035885 0.0020219869911670685 0.055546097457408905 0.2598409354686737
+0.16607627272605896 0.00044968255679123104 0.0018358345841988921 0.05984276533126831 0.2115275263786316
+0.213291198015213 0.0006491054082289338 0.0014706585789099336 0.04900344833731651 0.32589244842529297
+0.16936349868774414 0.0005953487125225365 0.0012290703598409891 0.04522907733917236 0.21646666526794434
+0.19224758446216583 0.0004409611865412444 0.0017674652626737952 0.03982824087142944 0.2901735305786133
+0.16578511893749237 0.00046719168312847614 0.0016261927084997296 0.045962244272232056 0.22620810568332672
+0.17512866854667664 0.0004419292963575572 0.0023093183990567923 0.04738125950098038 0.2029232531785965
+0.14787104725837708 0.0004282344598323107 0.001973031787201762 0.08114778995513916 0.16249947249889374
+0.20089931786060333 0.0005479777464643121 0.0017525125294923782 0.039794959127902985 0.3631784915924072
+0.18247924745082855 0.00044879072811454535 0.002035373356193304 0.05610937997698784 0.31612464785575867
+0.17372605204582214 0.00033800111850723624 0.0023582493886351585 0.07917391508817673 0.20503996312618256
+0.18096891045570374 0.00039262050995603204 0.002047688467428088 0.06408791989088058 0.20546592772006989
+0.155534565448761 0.00036819453816860914 0.0021486692130565643 0.041831012815237045 0.21849700808525085
+0.18881791830062866 0.0006209084531292319 0.0014024402480572462 0.03908637538552284 0.27875566482543945
+0.17286960780620575 0.0016627770382910967 0.001992312725633383 0.05413941666483879 0.25803810358047485
+0.18049116432666779 0.001043768716044724 0.0032776049338281155 0.07557062804698944 0.2544616460800171
+0.17610862851142883 0.0007119233487173915 0.0023148043546825647 0.10064413398504257 0.2147865742444992
+0.15809139609336853 0.0028708730824291706 0.003299601376056671 0.05699124559760094 0.19123515486717224
+0.1850052773952484 0.0012993754353374243 0.0024314397014677525 0.05682598799467087 0.24561433494091034
+0.1623600870370865 0.0009771476034075022 0.0022217154037207365 0.04357085004448891 0.2270749807357788
+0.15561699867248535 0.0034828793723136187 0.006840407848358154 0.060871466994285583 0.1968003511428833
+0.1748172640800476 0.0010467814281582832 0.0034244356211274862 0.09514614939689636 0.2043413519859314
+0.1790066659450531 0.0013630457688122988 0.0018641870701685548 0.08189883083105087 0.23033036291599274
+0.16190606355667114 0.0030947031918913126 0.003651921171694994 0.034181829541921616 0.18678158521652222
+0.1634342521429062 0.0007584153790958226 0.0016377917490899563 0.04730567708611488 0.24843142926692963
+0.1792086958885193 0.0012502140598371625 0.0029830005951225758 0.07382719963788986 0.25390130281448364
+0.18003131449222565 0.0018708965508267283 0.0024581989273428917 0.06824826449155807 0.2611769437789917
+0.17114302515983582 0.0005152505473233759 0.0019220011308789253 0.06814093887805939 0.20126301050186157
+0.1547117531299591 0.0006438480340875685 0.0030667511746287346 0.07633183151483536 0.19163750112056732
+0.1587640345096588 0.0004066737601533532 0.0014231106033548713 0.0427422896027565 0.24455682933330536
+0.16943347454071045 0.0008497175294905901 0.0023515275679528713 0.07628139853477478 0.24619238078594208
+0.16674524545669556 0.0010276080574840307 0.004126560874283314 0.11649245023727417 0.1926332712173462
+0.17361700534820557 0.0006395912496373057 0.0020918806549161673 0.05256560444831848 0.20458847284317017
+0.1690244823694229 0.0020718902815133333 0.0019251767080277205 0.04578087851405144 0.23986420035362244
+0.14371049404144287 0.0010797666618600488 0.0015330270398408175 0.030093248933553696 0.21032875776290894
+0.14622175693511963 0.0009700131486169994 0.003684037597849965 0.0590398870408535 0.1961587816476822
+0.1575433611869812 0.0011272687697783113 0.003883642377331853 0.08855138719081879 0.17336265742778778
+0.1634327620267868 0.00041536305798217654 0.0014929681783542037 0.04416574165225029 0.2655237913131714
+0.1819971203804016 0.0008300252957269549 0.003512696363031864 0.05200870335102081 0.2901034355163574
+0.2021234929561615 0.0003652494924608618 0.0027925949543714523 0.048429571092128754 0.24646197259426117
+0.16704100370407104 0.0012301066890358925 0.001135653001256287 0.10270405560731888 0.1919012814760208
+0.14312300086021423 0.0018898282432928681 0.0015970688546076417 0.05751384422183037 0.18954460322856903
+0.16214518249034882 0.000956431613303721 0.0010288101620972157 0.046634823083877563 0.26630884408950806
+0.15823772549629211 0.00033862784039229155 0.001205340144224465 0.04853476956486702 0.16892163455486298
+0.16373878717422485 0.0005075535736978054 0.0012518311850726604 0.05383491888642311 0.2127019613981247
+0.191579669713974 0.0006452276720665395 0.0020341831259429455 0.04811794310808182 0.2877417802810669
+0.17672470211982727 0.0006463851314038038 0.003312682034447789 0.06012604385614395 0.27018290758132935
+0.18738813698291779 0.0006973001873120666 0.003816433483734727 0.06012788414955139 0.28878360986709595
+0.1799134910106659 0.0009070358937606215 0.0029132808558642864 0.08788737654685974 0.2245272696018219
+0.1683150827884674 0.0009697688510641456 0.0012033635284751654 0.08156456053256989 0.23093438148498535
+0.17249566316604614 0.0004486480902414769 0.0013347856001928449 0.07619800418615341 0.2281958907842636
+0.18496151268482208 0.0008222814649343491 0.0014481095131486654 0.04701991379261017 0.2750914990901947
+0.16307061910629272 0.0016556355403736234 0.0019127103732898831 0.048010047525167465 0.2348792552947998
+0.17409539222717285 0.0009483139729127288 0.0017219935543835163 0.057023946195840836 0.22449737787246704
+0.15979497134685516 0.0005175701808184385 0.0018636673921719193 0.04481226205825806 0.21881666779518127
+0.15315446257591248 0.0008768909610807896 0.003017977811396122 0.04606960341334343 0.21527816355228424
+0.15812356770038605 0.0007312365341931581 0.002145945094525814 0.05100757256150246 0.1913892775774002
+0.15791693329811096 0.0005228340742178261 0.002141097094863653 0.09553618729114532 0.20171339809894562
+0.17514349520206451 0.0008956915116868913 0.0021986430510878563 0.05287954956293106 0.20393776893615723
+0.17383623123168945 0.0005576324765570462 0.0021251991856843233 0.03908916190266609 0.2527260482311249
+0.17828074097633362 0.00042078676051460207 0.0021999820601195097 0.048339858651161194 0.27485889196395874
+0.13898581266403198 0.0006274259649217129 0.002322838641703129 0.060068223625421524 0.17397449910640717
+0.1618732064962387 0.0006326106376945972 0.0019433065317571163 0.11519742757081985 0.18881866335868835
+0.1631726324558258 0.000415267248172313 0.0019610195886343718 0.047292474657297134 0.2021452933549881
+0.16926753520965576 0.0005675987922586501 0.0015510994708165526 0.04173080623149872 0.23893488943576813
+0.17886868119239807 0.0015595339937135577 0.0013810587115585804 0.047333620488643646 0.26936110854148865
+0.17227064073085785 0.0011768314288929105 0.002701418474316597 0.05735056847333908 0.2390354722738266
+0.16640432178974152 0.00034148679696954787 0.002323754597455263 0.07206690311431885 0.212007537484169
+0.17398516833782196 0.0019366444321349263 0.003053448162972927 0.06547445803880692 0.25435370206832886
+0.15684103965759277 0.0015266467817127705 0.002558598294854164 0.04833235219120979 0.24075616896152496
+0.17963019013404846 0.0003877962299156934 0.0012960080057382584 0.06160568445920944 0.2236662209033966
+0.1624518632888794 0.0023016545455902815 0.003254478331655264 0.08401177823543549 0.23371821641921997
+0.15558354556560516 0.0017316576559096575 0.003509696340188384 0.06304429471492767 0.2584799826145172
+0.13906750082969666 0.0003881235606968403 0.0013826959766447544 0.05537142977118492 0.17426052689552307
+0.1580631285905838 0.002474731532856822 0.0017688212683424354 0.05382456257939339 0.2174566686153412
+0.16324010491371155 0.0019546092953532934 0.001788342255167663 0.0431087426841259 0.217402383685112
+0.1318984031677246 0.00036760009243153036 0.0014027772704139352 0.047829415649175644 0.16874423623085022
+0.1581895649433136 0.0017995511880144477 0.002654885407537222 0.059712015092372894 0.18669290840625763
+0.17654913663864136 0.0014188190689310431 0.002130122622475028 0.054393064230680466 0.21306036412715912
+0.1494680643081665 0.0004770078230649233 0.0013262900756672025 0.07688035815954208 0.2154129296541214
+0.18375460803508759 0.0009486713679507375 0.0018839366966858506 0.042277123779058456 0.3047982454299927
+0.17072293162345886 0.0009277832577936351 0.002521331189200282 0.06965417414903641 0.21506530046463013
+0.20292611420154572 0.0006409861380234361 0.0023059165105223656 0.08682003617286682 0.2109648883342743
+0.14389613270759583 0.0004888843977823853 0.00223900331184268 0.04716220498085022 0.18776428699493408
+0.1597304791212082 0.0007071232539601624 0.0020307053346186876 0.04611440747976303 0.2729873061180115
+0.15105846524238586 0.0020378571934998035 0.0024742160458117723 0.04296865314245224 0.23257113993167877
+0.1569090485572815 0.0018598465248942375 0.004305914975702763 0.0781608298420906 0.15991611778736115
+0.12680789828300476 0.0005162961315363646 0.002223449293524027 0.058488879352808 0.155093252658844
+0.17941108345985413 0.0021100167650729418 0.004528388846665621 0.11327793449163437 0.21581192314624786
+0.18986621499061584 0.0009901680750772357 0.0024828705936670303 0.03352738171815872 0.31935104727745056
+0.14662718772888184 0.001123686321079731 0.026798130944371223 0.07530786097049713 0.15203537046909332
+0.1916559487581253 0.0020305279176682234 0.00454539293423295 0.06770686060190201 0.2577401101589203
+0.1497810333967209 0.000522274465765804 0.0014936148654669523 0.09281333535909653 0.2044350951910019
+0.1946190446615219 0.0013161514652892947 0.002362700179219246 0.09989703446626663 0.29892122745513916
+0.20424792170524597 0.0012794742360711098 0.0022095725871622562 0.08265399187803268 0.30767810344696045
+0.15775498747825623 0.000414436828577891 0.0015739103546366096 0.09467942267656326 0.19835522770881653
+0.1661386489868164 0.0009338976233266294 0.0016264189034700394 0.08574492484331131 0.22397388517856598
+0.1587103307247162 0.0005747202085331082 0.0009595506708137691 0.05180230364203453 0.20512060821056366
+0.14634551107883453 0.0003862583835143596 0.0013036048039793968 0.04486555606126785 0.2103407233953476
+0.16577781736850739 0.0003389486519154161 0.001123114605434239 0.04826575145125389 0.24790771305561066
+0.15467101335525513 0.00043146591633558273 0.0026541310362517834 0.07443197071552277 0.18894273042678833
+0.1728450059890747 0.000984414597041905 0.00195568916387856 0.05959020182490349 0.15390704572200775
+0.15021201968193054 0.0018065777840092778 0.0011970059713348746 0.05193084850907326 0.2227957546710968
+0.151614710688591 0.001353962579742074 0.0014235565904527903 0.03652488440275192 0.22727613151073456
+0.13175275921821594 0.0005123220034874976 0.0022536427713930607 0.05056532844901085 0.17676442861557007
+0.15163549780845642 0.0012145666405558586 0.004507614299654961 0.07066943496465683 0.20534682273864746
+0.15296632051467896 0.0005512736388482153 0.0019976000767201185 0.061427269130945206 0.22510911524295807
+0.15526266396045685 0.000564537534955889 0.0035833960864692926 0.04334300756454468 0.15954798460006714
+0.172522634267807 0.0004474467714317143 0.004590777214616537 0.04601473733782768 0.2196095734834671
+0.17476791143417358 0.0007822113111615181 0.0015710418811067939 0.047782719135284424 0.3086400628089905
+0.16362713277339935 0.0029676235280930996 0.002639045240357518 0.07334987819194794 0.23617307841777802
+0.14718490839004517 0.0013045038795098662 0.0024685319513082504 0.06754640489816666 0.19301065802574158
+0.17264673113822937 0.00059405283536762 0.0018680953653529286 0.08480358123779297 0.2523769736289978
+0.1526392698287964 0.0018734588520601392 0.0030922105070203543 0.03612867742776871 0.18752729892730713
+0.1588914543390274 0.00046556771849282086 0.0010241082636639476 0.04410563409328461 0.15465711057186127
+0.15514691174030304 0.0008606142946518958 0.0033623953349888325 0.0514269582927227 0.23957818746566772
+0.18909981846809387 0.0010871689300984144 0.004918997175991535 0.059978317469358444 0.35763660073280334
+0.12903687357902527 0.0002787273551803082 0.0029927340801805258 0.082845538854599 0.16788072884082794
+0.16616114974021912 0.002098285360261798 0.0017016157507896423 0.05678053945302963 0.14740592241287231
+0.16961613297462463 0.002532724756747484 0.0030645979568362236 0.04888530820608139 0.17342674732208252
+0.16887640953063965 0.00025603483663871884 0.0009412786457687616 0.07012268155813217 0.24233901500701904
+0.15350601077079773 0.0036271000280976295 0.004830698948353529 0.06008720025420189 0.21971361339092255
+0.1598324328660965 0.003155237063765526 0.007129158359020948 0.07810600847005844 0.22230494022369385
+0.1551290601491928 0.00038958765799179673 0.0018972961697727442 0.070749431848526 0.260637491941452
+0.1352955400943756 0.004484422504901886 0.004702538717538118 0.058457620441913605 0.166852205991745
+0.16952672600746155 0.0034927839878946543 0.005803138017654419 0.11849774420261383 0.15779484808444977
+0.16408199071884155 0.0007853485876694322 0.0013852034462615848 0.045805059373378754 0.18651151657104492
+0.1562364548444748 0.0061783152632415295 0.0036767187993973494 0.05223577097058296 0.2691049575805664
+0.13935242593288422 0.006017198320478201 0.007199026644229889 0.06134718284010887 0.23455892503261566
+0.17889150977134705 0.000329762464389205 0.0022052086424082518 0.09975980222225189 0.2464635670185089
+0.13723137974739075 0.006219363305717707 0.004472925793379545 0.04416237398982048 0.1491539180278778
+0.16149695217609406 0.007358028087764978 0.0062202466651797295 0.08560386300086975 0.1856038123369217
+0.1486910879611969 0.0004083063977304846 0.001204907544888556 0.037506353110075 0.1774713546037674
+0.14983627200126648 0.00523750577121973 0.006020719185471535 0.05088939145207405 0.2191755473613739
+0.1606920063495636 0.005891459062695503 0.006862862966954708 0.06681745499372482 0.2262401580810547
+0.14992745220661163 0.0004000262706540525 0.0016030120896175504 0.049995094537734985 0.15726742148399353
+0.1435297429561615 0.004103587009012699 0.004301922861486673 0.041541021317243576 0.18069273233413696
+0.16231992840766907 0.0044767283834517 0.004525869153439999 0.04066098481416702 0.24101074039936066
+0.15393829345703125 0.00041363490163348615 0.0022163891699165106 0.0681207925081253 0.24169503152370453
+0.16942910850048065 0.0027028501499444246 0.0029038412030786276 0.08625727891921997 0.23570501804351807
+0.15077859163284302 0.0023653667885810137 0.001425171853043139 0.04453275725245476 0.2100742608308792
+0.17000679671764374 0.0008361487998627126 0.0020514714997261763 0.04929715022444725 0.1898730844259262
+0.1532590240240097 0.00033014165819622576 0.001724952831864357 0.0585675947368145 0.16382035613059998
+0.14060549437999725 0.00033327520941384137 0.0014935809886083007 0.048434361815452576 0.20093493163585663
+0.16268599033355713 0.0004386796208564192 0.003749568248167634 0.07379232347011566 0.28547361493110657
+0.14112310111522675 0.0009305051644332707 0.002308325609192252 0.06818539649248123 0.17837074398994446
+0.1445302516222 0.002167308237403631 0.0017796024912968278 0.04370490834116936 0.1412135362625122
+0.16337043046951294 0.0007726454641669989 0.0015465859323740005 0.034403253346681595 0.18565338850021362
+0.16270580887794495 0.0012268305290490389 0.001403254922479391 0.039598945528268814 0.20610666275024414
+0.1922343671321869 0.0038467401172965765 0.006295326631516218 0.06057798117399216 0.32801753282546997
+0.1527608335018158 0.0011225028429180384 0.005525296553969383 0.09309563785791397 0.2273412048816681
+0.17077308893203735 0.0024028499610722065 0.0017348038963973522 0.08085701614618301 0.27540096640586853
+0.1775236576795578 0.006460960488766432 0.006840928923338652 0.06725990027189255 0.2002018541097641
+0.15015339851379395 0.0014283066848292947 0.004422585479915142 0.05241886153817177 0.15961088240146637
+0.16223958134651184 0.0030843077693134546 0.0011755378218367696 0.057299114763736725 0.23330096900463104
+0.15313397347927094 0.010873244144022465 0.005770276300609112 0.047928184270858765 0.20471662282943726
+0.14515426754951477 0.004624699708074331 0.004311594646424055 0.04919048771262169 0.20432546734809875
+0.1435403972864151 0.0010973748285323381 0.0014156574616208673 0.04470186308026314 0.18176887929439545
+0.13489922881126404 0.008451336994767189 0.00392106780782342 0.04127688333392143 0.1480904072523117
+0.12982504069805145 0.006277484819293022 0.0025301864370703697 0.0412493497133255 0.1602799892425537
+0.14130431413650513 0.0005116951069794595 0.0020903695840388536 0.07259978353977203 0.1521623134613037
+0.15450623631477356 0.003019308438524604 0.00564387533813715 0.04316128417849541 0.22737115621566772
+0.173423171043396 0.003168130526319146 0.00422097509726882 0.03997186943888664 0.28441721200942993
+0.1375156044960022 0.0005561906145885587 0.0012374621583148837 0.03926244378089905 0.14505645632743835
+0.15928232669830322 0.0016100568464025855 0.002929174341261387 0.03593529388308525 0.16371777653694153
+0.17514559626579285 0.002212506951764226 0.0043991259299218655 0.06048309803009033 0.22315458953380585
+0.15020184218883514 0.0004066238470841199 0.0025501460768282413 0.09289488941431046 0.231758251786232
+0.14887180924415588 0.0012228215346112847 0.001973924692720175 0.06445755064487457 0.19988873600959778
+0.14837156236171722 0.0023693169932812452 0.0014533160720020533 0.059103041887283325 0.14973308145999908
+0.15694409608840942 0.001372158178128302 0.0014452786417677999 0.04890747368335724 0.21632565557956696
+0.1532418131828308 0.00047967818682082 0.0011642653262242675 0.04167646914720535 0.16796690225601196
+0.13709579408168793 0.0003736827929969877 0.0013533140299841762 0.06328168511390686 0.1358177363872528
+0.15692245960235596 0.0005516766686923802 0.0022586206905543804 0.060904715210199356 0.23362423479557037
+0.1792842596769333 0.001007421175017953 0.0018334692576900125 0.049715135246515274 0.30053481459617615
+0.13241475820541382 0.000938287703320384 0.0017483568517491221 0.05603837966918945 0.1493358165025711
+0.1315031349658966 0.0006164010846987367 0.0012893318198621273 0.03273141756653786 0.1719702184200287
+0.14048385620117188 0.0003598732000682503 0.0014730875845998526 0.02938900515437126 0.2336091548204422
+0.151702880859375 0.000658639008179307 0.002354271477088332 0.04072168096899986 0.2159184217453003
+0.1390669345855713 0.00048533070366829634 0.002107231877744198 0.05749179795384407 0.13142302632331848
+0.17600169777870178 0.00043896265560761094 0.0014659189619123936 0.07845544070005417 0.23097600042819977
+0.1631738245487213 0.0008128836052492261 0.0021471513900905848 0.042094212025403976 0.2106504887342453
+0.1572681963443756 0.00031829086947254837 0.0018025871831923723 0.06116040050983429 0.16633839905261993
+0.1782337874174118 0.0007676022360101342 0.0013082495424896479 0.03884333372116089 0.24136340618133545
+0.15754428505897522 0.0014763438375666738 0.0021805171854794025 0.04531513527035713 0.2392236739397049
+0.13733622431755066 0.0007688174373470247 0.0021255495958030224 0.09211482852697372 0.18173731863498688
+0.14983084797859192 0.0005441855173557997 0.0019379484001547098 0.0563204400241375 0.19509951770305634
+0.16020235419273376 0.0012111031683161855 0.0027437410317361355 0.033654116094112396 0.2243872731924057
+0.13254374265670776 0.0005597960553131998 0.001469098380766809 0.04672309011220932 0.16488116979599
+0.13999593257904053 0.0007133744074963033 0.001888337661512196 0.05289006233215332 0.21121978759765625
+0.15377122163772583 0.00089616933837533 0.0019791768863797188 0.04673745483160019 0.2184809446334839
+0.14081135392189026 0.00032466492848470807 0.001039163675159216 0.032430749386548996 0.2135213315486908
+0.16289404034614563 0.0005659068119712174 0.0012484622420743108 0.056075569242239 0.21527709066867828
+0.1558804214000702 0.0003675309126265347 0.001206700806505978 0.042887721210718155 0.20155881345272064
+0.12988126277923584 0.00035609875340014696 0.001285349833779037 0.04883992299437523 0.16739927232265472
+0.1386055052280426 0.0003205161774531007 0.0013411921681836247 0.049241889268159866 0.182221919298172
+0.16815054416656494 0.0003229090361855924 0.0014364244416356087 0.04103907197713852 0.261505126953125
+0.14712195098400116 0.00035601778654381633 0.0012965978821739554 0.05954157933592796 0.20463253557682037
+0.1364709585905075 0.0002842747198883444 0.0012516515562310815 0.04808954522013664 0.19431553781032562
+0.14168867468833923 0.0003475989797152579 0.0010914038866758347 0.04272863641381264 0.18722878396511078
+0.17814818024635315 0.00025633000768721104 0.0014530961634591222 0.05321907252073288 0.2534421682357788
+0.1587577909231186 0.0003075535350944847 0.001374793704599142 0.04503469169139862 0.22645440697669983
+0.12162443995475769 0.00031915842555463314 0.0011455458588898182 0.05341692268848419 0.10266143828630447
+0.14006879925727844 0.00041939341463148594 0.0010717349359765649 0.04072866588830948 0.23494304716587067
+0.13607259094715118 0.00041167609742842615 0.0012447880581021309 0.06566429883241653 0.22569549083709717
+0.16782726347446442 0.0003474665863905102 0.0015306900022551417 0.06104803457856178 0.17296236753463745
+0.15671813488006592 0.0004870931734330952 0.0014527764869853854 0.06680942326784134 0.20684832334518433
+0.15727345645427704 0.0005113267106935382 0.001120457542128861 0.03590996563434601 0.2596651017665863
+0.1741461157798767 0.0003043313918169588 0.0013956474140286446 0.03997671231627464 0.2760288417339325
+0.12669430673122406 0.0008180487784557045 0.00189877487719059 0.04521171376109123 0.12231519818305969
+0.15519165992736816 0.00044377928134053946 0.0014565277379006147 0.047632262110710144 0.20208893716335297
+0.1540023386478424 0.00033442824496887624 0.0012955291895195842 0.028491713106632233 0.1929573267698288
+0.15520602464675903 0.0004585845454130322 0.0010857413290068507 0.03731254115700722 0.23319502174854279
+0.13225851953029633 0.0003600680793169886 0.0021470149513334036 0.05331755802035332 0.22454094886779785
+0.11787709593772888 0.00036627970985136926 0.002198632573708892 0.0765647441148758 0.15850868821144104
+0.14340399205684662 0.000587295857258141 0.0015749390004202724 0.05364936962723732 0.15616364777088165
+0.12322661280632019 0.0005458885570988059 0.001958948327228427 0.030294209718704224 0.1576138585805893
+0.15868282318115234 0.0003674672916531563 0.0014694692799821496 0.02829233929514885 0.19652706384658813
+0.14376091957092285 0.0016672341153025627 0.0016081067733466625 0.06536541879177094 0.24882076680660248
+0.1285514235496521 0.0013806922361254692 0.0017354178708046675 0.06212041154503822 0.12122135609388351
+0.1391715705394745 0.0003141839406453073 0.0018624610966071486 0.061534054577350616 0.1864292025566101
+0.14851143956184387 0.0013042379869148135 0.004098041914403439 0.03508942946791649 0.1990068107843399
+0.16672655940055847 0.00036251009441912174 0.001880313386209309 0.03908179700374603 0.23475360870361328
+0.13900154829025269 0.0011024390114471316 0.002118067815899849 0.0915357917547226 0.16130763292312622
+0.1465548872947693 0.0014426918933168054 0.0027563830371946096 0.05711502209305763 0.18701669573783875
+0.12373367697000504 0.00041343882912769914 0.0012884200550615788 0.03728554770350456 0.14541226625442505
+0.1564638316631317 0.0012892024824395776 0.0016544501995667815 0.03337695822119713 0.21578344702720642
+0.14761817455291748 0.0008437092183157802 0.001067137811332941 0.03970316797494888 0.2106439769268036
+0.14470526576042175 0.0003083993215113878 0.0018344774143770337 0.0555570013821125 0.21766851842403412
+0.1359788179397583 0.0007126752170734107 0.002316639292985201 0.05404197797179222 0.16694393754005432
+0.13140080869197845 0.00034909340320155025 0.00120355689432472 0.04574067145586014 0.1583787351846695
+0.16053029894828796 0.00044561861432157457 0.001198198297061026 0.0362892784178257 0.22877748310565948
+0.14430266618728638 0.0003860687429551035 0.0013353422982618213 0.035216324031353 0.19542966783046722
+0.15068349242210388 0.0003754953795578331 0.0018032421357929707 0.04849040508270264 0.1906093806028366
+0.14919477701187134 0.0003838149714283645 0.002392357215285301 0.06008230522274971 0.24203065037727356
+0.11588239669799805 0.0004147608997300267 0.001469496637582779 0.050936952233314514 0.1357032060623169
+0.12280139327049255 0.0007961212540976703 0.0019667905289679766 0.04106300324201584 0.15783430635929108
+0.1378190815448761 0.000439566996647045 0.0012964988127350807 0.046088699251413345 0.15888109803199768
+0.12855640053749084 0.0005213965778239071 0.0013733092928305268 0.042770981788635254 0.16104425489902496
+0.1432773470878601 0.0010483170626685023 0.002121453871950507 0.04436963051557541 0.25609010457992554
+0.14924629032611847 0.000497993198223412 0.0022079150658100843 0.05743510276079178 0.2330770045518875
+0.11143217980861664 0.0007712788647040725 0.00209377845749259 0.059479013085365295 0.09193599969148636
+0.12901972234249115 0.002016764134168625 0.003425597446039319 0.03761924430727959 0.15069721639156342
+0.13521629571914673 0.0006655256729573011 0.002897716825827956 0.04065017029643059 0.2159140557050705
+0.156062051653862 0.0013784102629870176 0.001268055522814393 0.048706576228141785 0.22203250229358673
+0.13558223843574524 0.003166313748806715 0.0026618288829922676 0.09716441482305527 0.15860646963119507
+0.14741969108581543 0.0007188470335677266 0.0013541376683861017 0.05048319697380066 0.21498198807239532
+0.1455378234386444 0.0007563526160083711 0.0028628893196582794 0.045901816338300705 0.1685793548822403
+0.12639759480953217 0.0011615118710324168 0.01630440726876259 0.03996247053146362 0.1717960685491562
+0.14419960975646973 0.00032225006725639105 0.0008575097308494151 0.03580210357904434 0.15349557995796204
+0.1413620114326477 0.001338232192210853 0.002373812021687627 0.046798959374427795 0.2060064673423767
+0.16152822971343994 0.0006541719194501638 0.0023341404739767313 0.036735039204359055 0.2946029603481293
+0.13905848562717438 0.00044677333789877594 0.0011723259231075644 0.0398908406496048 0.21620333194732666
+0.1612270325422287 0.00156329944729805 0.0009251617593690753 0.04880028963088989 0.16765372455120087
+0.13232390582561493 0.0015035083051770926 0.0008775537135079503 0.050393491983413696 0.188592791557312
+0.1423148363828659 0.00045049897744320333 0.0009269100846722722 0.051181990653276443 0.18226368725299835
+0.14275527000427246 0.000653824012260884 0.0014178442070260644 0.0954805463552475 0.21085317432880402
+0.14552909135818481 0.0010204734280705452 0.0015264726243913174 0.043442390859127045 0.20761537551879883
+0.1304255723953247 0.0004381639009807259 0.0011168112978339195 0.0773896872997284 0.17933130264282227
+0.1418417990207672 0.0003797932877205312 0.001483539817854762 0.04605044797062874 0.1336015909910202
+0.12973561882972717 0.0005701172631233931 0.001882776035927236 0.039110489189624786 0.11467818170785904
+0.13821682333946228 0.00034426344791427255 0.0016585220582783222 0.03816862404346466 0.22033871710300446
+0.16708889603614807 0.0006885736365802586 0.0020571586210280657 0.07111568003892899 0.3129979968070984
+0.11983956396579742 0.0007623386918567121 0.0022388806100934744 0.07255113124847412 0.1463753581047058
+0.16593442857265472 0.0002781577641144395 0.0011035868665203452 0.04295515641570091 0.20562633872032166
+0.16998183727264404 0.0007775778649374843 0.0021197248715907335 0.03926905244588852 0.23341530561447144
+0.14023059606552124 0.00034696428338065743 0.0013002138584852219 0.036323051899671555 0.1713910549879074
+0.15005724132061005 0.0006288861040957272 0.0016395257553085685 0.04201078414916992 0.2522372901439667
+0.1366344839334488 0.0009537941659800708 0.002424312988296151 0.06810761243104935 0.24390697479248047
+0.14321258664131165 0.00026234128745272756 0.0010796518763527274 0.05996130779385567 0.17313481867313385
+0.1224987804889679 0.00106474244967103 0.0021719392389059067 0.046309251338243484 0.13180118799209595
+0.1317654550075531 0.0007333591929636896 0.001883764867670834 0.06379780918359756 0.1100732609629631
+0.14599385857582092 0.0004456535680219531 0.000995903741568327 0.04454910382628441 0.20717723667621613
+0.13664062321186066 0.0023530065082013607 0.004251458682119846 0.04682613164186478 0.21041859686374664
+0.13926109671592712 0.0008096357341855764 0.003953038714826107 0.05926733464002609 0.20754900574684143
+0.1324797421693802 0.0014387795235961676 0.0011831950396299362 0.0411248616874218 0.19008196890354156
+0.1518629491329193 0.00421943562105298 0.0029936213977634907 0.0365300215780735 0.2028747946023941
+0.13857056200504303 0.0009878476848825812 0.0013911731075495481 0.053020600229501724 0.14926360547542572
+0.1311812847852707 0.0015208310214802623 0.0016299077542498708 0.03346122428774834 0.139219731092453
+0.13535043597221375 0.0040877205319702625 0.003663428360596299 0.03412460535764694 0.16219964623451233
+0.13163109123706818 0.0012475934345275164 0.0023449293803423643 0.03653452545404434 0.20783743262290955
+0.14548996090888977 0.0009440633002668619 0.0015850710915401578 0.03618815168738365 0.2661103904247284
+0.15230746567249298 0.0031874666456133127 0.002545149065554142 0.056357041001319885 0.24639280140399933
+0.14952751994132996 0.0017822820227593184 0.002430454595014453 0.0667184591293335 0.11679597198963165
+0.16173258423805237 0.00028178736101835966 0.0018544510239735246 0.054731301963329315 0.1297978013753891
+0.14939327538013458 0.0020916173234581947 0.0010081542422994971 0.0303987804800272 0.23192071914672852
+0.14130768179893494 0.005029952619224787 0.0017922292463481426 0.047523632645606995 0.2333318144083023
+0.11881498992443085 0.0025200245436280966 0.003693651407957077 0.05795549601316452 0.13137704133987427
+0.15016044676303864 0.000818660540971905 0.0016999151557683945 0.06580580770969391 0.1766045093536377
+0.13597600162029266 0.006363054271787405 0.004697887226939201 0.044317495077848434 0.14268375933170319
+0.11630672961473465 0.003946896176785231 0.00397774763405323 0.03752976655960083 0.16956692934036255
+0.1654287576675415 0.0006591047276742756 0.0017687476938590407 0.07258832454681396 0.2867911756038666
+0.13821235299110413 0.006367742549628019 0.009367396123707294 0.07408686727285385 0.18871605396270752
+0.14491719007492065 0.0022857110016047955 0.00366595177911222 0.0538182333111763 0.15266592800617218
+0.13868604600429535 0.0014107237802818418 0.002779661677777767 0.03725244849920273 0.15777340531349182
+0.13731852173805237 0.004402197897434235 0.0067799827083945274 0.035566311329603195 0.17079763114452362
+0.10999158024787903 0.0004479198541957885 0.0011320124613121152 0.04427776858210564 0.16796275973320007
+0.13838334381580353 0.002623133361339569 0.004127927124500275 0.07759249955415726 0.21543559432029724
+0.11148229986429214 0.003144323593005538 0.004487687721848488 0.0663648396730423 0.12397131323814392
+0.12747938930988312 0.00020592502551153302 0.0007943973760120571 0.04328705742955208 0.15175192058086395
+0.10924297571182251 0.0018875610549002886 0.0024549830704927444 0.02772848680615425 0.1154736876487732
+0.12552092969417572 0.0009531615069136024 0.0010370148811489344 0.04205074533820152 0.11400436609983444
+0.12748947739601135 0.00035718586877919734 0.004975144751369953 0.06727829575538635 0.18745185434818268
+0.13339537382125854 0.0005328509141691029 0.005266556050628424 0.06477448344230652 0.2130078822374344
+0.11611270904541016 0.000474109168862924 0.0011790404096245766 0.04391266405582428 0.14058536291122437
+0.13844998180866241 0.0013086176477372646 0.002817355329170823 0.033364735543727875 0.16926252841949463
+0.13465452194213867 0.0002839606604538858 0.0019584596157073975 0.06407076120376587 0.16432246565818787
+0.14809900522232056 0.0012181730708107352 0.0011981932912021875 0.05455082654953003 0.1666633039712906
+0.11670595407485962 0.002113677328452468 0.0009100814932025969 0.051833488047122955 0.1570468693971634
+0.14773546159267426 0.0006912917597219348 0.0011854702606797218 0.04226174205541611 0.2571536600589752
+0.13463208079338074 0.00025010789977386594 0.0018982070032507181 0.057183168828487396 0.17428651452064514
+0.12398618459701538 0.0002685851650312543 0.0011702978517860174 0.06111064925789833 0.1309177577495575
+0.15449285507202148 0.00036207199445925653 0.0007987324497662485 0.060860343277454376 0.18085041642189026
+0.1326766312122345 0.000354411662556231 0.001344462507404387 0.04129846394062042 0.2095596343278885
+0.1215556412935257 0.0002991096116602421 0.0026150811463594437 0.05063547194004059 0.18241089582443237
+0.12997449934482574 0.0005411570309661329 0.0027426299639046192 0.044830016791820526 0.20807135105133057
+0.13490624725818634 0.0014373388839885592 0.0013262430438771844 0.0702943354845047 0.17221370339393616
+0.13176442682743073 0.00171189964748919 0.0010213855421170592 0.03839733451604843 0.13698066771030426
+0.14201247692108154 0.0006787599995732307 0.0010631504701450467 0.03366408869624138 0.1956740766763687
+0.13173240423202515 0.0006501636235043406 0.0015822931891307235 0.05135487765073776 0.17756272852420807
+0.12201909720897675 0.0015263869427144527 0.0018288405844941735 0.041804831475019455 0.1663890928030014
+0.14252600073814392 0.0006915428093634546 0.0015811165794730186 0.08652402460575104 0.2026306837797165
+0.13472768664360046 0.00041886416147463024 0.00247043464332819 0.03808988258242607 0.17586907744407654
+0.11087219417095184 0.0005926082376390696 0.003221881343051791 0.03564377874135971 0.14194926619529724
+0.1495429426431656 0.00035871152067556977 0.0020164859015494585 0.05336989462375641 0.1916651725769043
+0.14317524433135986 0.001225736690685153 0.0010794271947816014 0.04346635565161705 0.16498300433158875
+0.1263301521539688 0.0014416433405131102 0.0013353974791243672 0.0461624376475811 0.13628517091274261
+0.09547098726034164 0.0004404334758874029 0.0008806065306998789 0.04056798294186592 0.11602769047021866
+0.15672893822193146 0.000731747830286622 0.0012954723788425326 0.03029157593846321 0.29207053780555725
+0.1307658851146698 0.0007657024543732405 0.001169770141132176 0.055337127298116684 0.21868927776813507
+0.13154560327529907 0.010760526172816753 0.0015284609980881214 0.08923525363206863 0.1398523598909378
+0.14568597078323364 0.00037987082032486796 0.0009417812689207494 0.04192483425140381 0.15178894996643066
+0.12473706901073456 0.000285505288047716 0.0009106163051910698 0.03494979813694954 0.11116049438714981
+0.13565665483474731 0.0006820731214247644 0.0008508904138579965 0.03294847533106804 0.1954387128353119
+0.10083691030740738 0.0011885445564985275 0.002969219349324703 0.04017956927418709 0.13660414516925812
+0.1533668041229248 0.00040355997043661773 0.0031707361340522766 0.06607848405838013 0.21898001432418823
+0.1274152249097824 0.0019657141529023647 0.0013102511875331402 0.0527571476995945 0.16309158504009247
+0.13821569085121155 0.003503338899463415 0.004043943248689175 0.03424706310033798 0.12167073786258698
+0.14595648646354675 0.0005866693099960685 0.0022652950137853622 0.030247163027524948 0.19339431822299957
+0.15877887606620789 0.0025987126864492893 0.0014448680449277163 0.036863137036561966 0.24069003760814667
+0.12352196872234344 0.0062450445257127285 0.004379448480904102 0.05720817670226097 0.1454155594110489
+0.14005011320114136 0.001343295443803072 0.002221380127593875 0.06333727389574051 0.1795578896999359
+0.13337823748588562 0.0018048587953671813 0.003075971035286784 0.03526376187801361 0.19517682492733002
+0.14258086681365967 0.0038633677177131176 0.003807684639468789 0.028964072465896606 0.1841547191143036
+0.12056492269039154 0.0006192519795149565 0.0017213618848472834 0.041850052773952484 0.16108104586601257
+0.11675454676151276 0.0012732358882203698 0.0048257471062242985 0.0653834268450737 0.11983779072761536
+0.1258617788553238 0.0013576530618593097 0.0029690039809793234 0.03864411637187004 0.19602501392364502
+0.10222397744655609 0.00030690908897668123 0.0011836127378046513 0.02901139296591282 0.13248221576213837
+0.13544830679893494 0.0008345226524397731 0.001967775635421276 0.027140799909830093 0.21463975310325623
+0.12421859055757523 0.00041705096373334527 0.0009819488041102886 0.04042922332882881 0.17934808135032654
+0.11632205545902252 0.00031530196429230273 0.0013118952047079802 0.06409582495689392 0.12475232034921646
+0.12522774934768677 0.0003544828505255282 0.0011406003031879663 0.045376770198345184 0.1478768289089203
+0.12489122897386551 0.00029311463003978133 0.001244824263267219 0.025852475315332413 0.14584743976593018
+0.12419231235980988 0.0004913723096251488 0.000932242430280894 0.0314442403614521 0.17363883554935455
+0.11230109632015228 0.0007483939989469945 0.0018869859632104635 0.05022703856229782 0.172075554728508
+0.10422175377607346 0.00022209688904695213 0.0018359888345003128 0.1064145565032959 0.11091684550046921
+0.12028281390666962 0.0012006420874968171 0.001507409499026835 0.04469609260559082 0.11854091286659241
+0.13381436467170715 0.0016261364798992872 0.0027966273482888937 0.0358886793255806 0.1458921879529953
+0.13612255454063416 0.00026040629018098116 0.0008080361294560134 0.0406528078019619 0.19652386009693146
+0.13424356281757355 0.0029018670320510864 0.004031907767057419 0.04880501702427864 0.18095460534095764
+0.12774676084518433 0.001913469168357551 0.0048713237047195435 0.06890679150819778 0.14906516671180725
+0.12540900707244873 0.00075073930202052 0.0009358702227473259 0.044245555996894836 0.17498023808002472
+0.14082276821136475 0.004000440705567598 0.003650550963357091 0.033873528242111206 0.20528815686702728
+0.12094024568796158 0.001166084548458457 0.0014908408047631383 0.06411758810281754 0.1220608651638031
+0.13819611072540283 0.001230973401106894 0.001909456797875464 0.04784087464213371 0.15676483511924744
+0.11023320257663727 0.002995390212163329 0.0039811204187572 0.047552190721035004 0.1605679839849472
+0.1557430624961853 0.0007064415840432048 0.001343066105619073 0.03667512908577919 0.22886861860752106
+0.10822023451328278 0.001001917407847941 0.0013161045499145985 0.03516639769077301 0.13245166838169098
+0.12843388319015503 0.0020969328470528126 0.002331511117517948 0.04168250784277916 0.15936250984668732
+0.1311430186033249 0.0005428873118944466 0.001098266919143498 0.03979033604264259 0.1256319135427475
+0.11242355406284332 0.0006814829539507627 0.0009465118055231869 0.037111252546310425 0.128267303109169
+0.14738942682743073 0.001563685480505228 0.0014352165162563324 0.06049657240509987 0.24473512172698975
+0.13777956366539001 0.0010936778271570802 0.0018872724613174796 0.04246644675731659 0.24194175004959106
+0.1223621666431427 0.0003682765818666667 0.00168609747197479 0.06048731878399849 0.14545056223869324
+0.12013198435306549 0.0023119805846363306 0.0019285804592072964 0.05512222647666931 0.1317368745803833
+0.12003033608198166 0.002328595845028758 0.003544585546478629 0.03751889243721962 0.09274213016033173
+0.1221669465303421 0.000274348451057449 0.0012620603665709496 0.04129168018698692 0.12876297533512115
+0.12559190392494202 0.003921143244951963 0.002118214964866638 0.06923245638608932 0.1894112378358841
+0.12659180164337158 0.004980152007192373 0.005868101958185434 0.06212475523352623 0.18860003352165222
+0.14847445487976074 0.00023536973458249122 0.001719284220598638 0.06744541227817535 0.25180062651634216
+0.0994570404291153 0.005170358810573816 0.004120550584048033 0.04598383605480194 0.1082446426153183
+0.15478447079658508 0.006010459735989571 0.008198177441954613 0.03342045098543167 0.11434486508369446
+0.12754637002944946 0.00026874354807659984 0.0022331918589770794 0.02982514724135399 0.1343969851732254
+0.1326504349708557 0.007298049051314592 0.0022037497255951166 0.03209954500198364 0.22290639579296112
+0.13460129499435425 0.010341204702854156 0.006242326460778713 0.04511499032378197 0.2327692061662674
+0.1252232939004898 0.0009107889491133392 0.001950230565853417 0.05196387320756912 0.18308962881565094
+0.15882617235183716 0.005322739481925964 0.003981407731771469 0.04276357963681221 0.15006208419799805
+0.1531912386417389 0.009431390091776848 0.00634384760633111 0.07841774821281433 0.1517602652311325
+0.14078113436698914 0.000867582974024117 0.001003870158456266 0.029587935656309128 0.24041242897510529
+0.1336766481399536 0.004177918657660484 0.010120866820216179 0.05901815742254257 0.18496301770210266
+0.12035669386386871 0.005467814859002829 0.012629068456590176 0.05608988180756569 0.1694083958864212
+0.12034569680690765 0.0002897411468438804 0.0024749482981860638 0.03195609897375107 0.1734074503183365
+0.12397817522287369 0.005271072033792734 0.004260636866092682 0.0363040417432785 0.12373880296945572
+0.12714995443820953 0.005120403133332729 0.0052546123042702675 0.053951866924762726 0.13916724920272827
+0.14800125360488892 0.00022838922450318933 0.0017474759370088577 0.03514842689037323 0.1825825572013855
+0.10873226076364517 0.004417299758642912 0.003452964359894395 0.0648384690284729 0.15127776563167572
+0.12463469803333282 0.0044289082288742065 0.0026022223755717278 0.05467594042420387 0.204998180270195
+0.14840592443943024 0.0003692202444653958 0.0011525199515745044 0.04551234841346741 0.25626930594444275
+0.1298523247241974 0.0016153478063642979 0.003811961505562067 0.040302496403455734 0.08198683708906174
+0.1337255984544754 0.0020545409061014652 0.004088586661964655 0.03872709348797798 0.1419951617717743
+0.14466333389282227 0.00025171530432999134 0.0012081696186214685 0.04750378057360649 0.20510198175907135
+0.13113519549369812 0.0020249926019459963 0.0019543017260730267 0.03388741612434387 0.22652731835842133
+0.14303013682365417 0.0027190984692424536 0.00357028073631227 0.039540402591228485 0.2226027250289917
+0.11359536647796631 0.00036867475137114525 0.0012756335781887174 0.051623329520225525 0.10850498080253601
+0.11131767183542252 0.002106273313984275 0.0014698655577376485 0.03612339869141579 0.09837793558835983
+0.13097374141216278 0.0031505804508924484 0.0014958461979404092 0.07638859003782272 0.17143826186656952
+0.10177243500947952 0.00048448258894495666 0.0011307939421385527 0.045916926115751266 0.15237943828105927
+0.11272715032100677 0.0008761111530475318 0.004543565213680267 0.058013517409563065 0.18265406787395477
+0.11328837275505066 0.0008109475020319223 0.0019577716011554003 0.050507742911577225 0.18217939138412476
+0.12208457291126251 0.00028836526325903833 0.0008959004771895707 0.039228226989507675 0.1358804702758789
+0.11500370502471924 0.0007182956323958933 0.0027787014842033386 0.0605609267950058 0.1165798231959343
+0.1366175413131714 0.0003020793665200472 0.0012273943284526467 0.031159203499555588 0.1546446532011032
+0.10835645347833633 0.0012732887407764792 0.0014021220849826932 0.04153646528720856 0.1519368290901184
+0.13551977276802063 0.001338653964921832 0.0025347836781293154 0.05565715953707695 0.22622422873973846
+0.12937305867671967 0.00029540032846853137 0.0014288700185716152 0.0493028461933136 0.18986405432224274
+0.1241728663444519 0.001377395004965365 0.002116698771715164 0.03992069512605667 0.12937572598457336
+0.1339903473854065 0.0011800267966464162 0.0022739972919225693 0.032896753400564194 0.15254607796669006
+0.11537877470254898 0.00040129516855813563 0.0007313204114325345 0.03712649270892143 0.1377154141664505
+0.13010400533676147 0.0020075759384781122 0.0029272788669914007 0.03978997841477394 0.2038145214319229
+0.1471422016620636 0.0013751444639638066 0.0030751849990338087 0.05677267163991928 0.19750849902629852
+0.09759503602981567 0.0004384205967653543 0.0010914535960182548 0.06424761563539505 0.06560077518224716
+0.11924008280038834 0.0025620029773563147 0.0017907872097566724 0.04191737622022629 0.13536690175533295
+0.11425529420375824 0.0014967508614063263 0.0010524949757382274 0.042237818241119385 0.15366604924201965
+0.11345174163579941 0.00035019416827708483 0.0019963665399700403 0.037989385426044464 0.16313913464546204
+0.12917855381965637 0.0014454307965934277 0.004336359910666943 0.06595969945192337 0.1497987061738968
+0.13045194745063782 0.00035386401577852666 0.0013887669192627072 0.0749378353357315 0.1473664492368698
+0.11881193518638611 0.0008123904117383063 0.0019321972504258156 0.022423801943659782 0.17983366549015045
+0.1106506958603859 0.0006333505734801292 0.0019026834052056074 0.02297794446349144 0.18009187281131744
+0.11897534877061844 0.00036385050043463707 0.001435564598068595 0.05679620802402496 0.18405388295650482
+0.10539974272251129 0.000857690756674856 0.0017151408828794956 0.06315749883651733 0.0957276001572609
+0.1287880688905716 0.0002543067676015198 0.0012382043059915304 0.03351869061589241 0.1528235673904419
+0.11333197355270386 0.0002733874134719372 0.002777932910248637 0.027531614527106285 0.14620347321033478
+0.11112206429243088 0.008701429702341557 0.001071175909601152 0.027482597157359123 0.12533006072044373
+0.10287188738584518 0.0010855946457013488 0.002516185399144888 0.05416902154684067 0.12153209000825882
+0.11909861862659454 0.0005258202436380088 0.0031739422120153904 0.0682201087474823 0.18359284102916718
+0.11332117021083832 0.0008870859746821225 0.0014554521767422557 0.08246692270040512 0.15852099657058716
+0.1526545137166977 0.002329554408788681 0.0017432600725442171 0.05054338276386261 0.2511833906173706
+0.11984990537166595 0.0008259150199592113 0.0011398263741284609 0.044357944279909134 0.12005198746919632
+0.12851838767528534 0.00038834023871459067 0.0018836137605831027 0.09920600801706314 0.11037272214889526
+0.10854442417621613 0.000833439058624208 0.001235519303008914 0.03451170772314072 0.11513125896453857
+0.10650445520877838 0.00047757732681930065 0.0007883539656177163 0.026979615911841393 0.14076802134513855
+0.13020764291286469 0.00021625444060191512 0.0010199418757110834 0.03192654624581337 0.19536146521568298
+0.12409428507089615 0.00021858107356820256 0.0013888988178223372 0.0428621880710125 0.2037959098815918
+0.13312843441963196 0.0002225942153017968 0.0012284518452361226 0.044492680579423904 0.17870590090751648
+0.12855525314807892 0.00020238719298504293 0.0013346424093469977 0.03767511621117592 0.16963784396648407
+0.1270434856414795 0.00023303061607293785 0.001091444748453796 0.029316755011677742 0.13028283417224884
+0.12914881110191345 0.0010096670594066381 0.0008810185827314854 0.029112640768289566 0.16241328418254852
+0.10364952683448792 0.0015647569671273232 0.0028011600952595472 0.033319294452667236 0.15696316957473755
+0.12022703886032104 0.00038554586353711784 0.004875303246080875 0.06586042791604996 0.1974797397851944
+0.11653675884008408 0.0019437780138105154 0.002859198721125722 0.09177852421998978 0.14493009448051453
+0.11779289692640305 0.006063312757760286 0.0026694554835557938 0.03601934760808945 0.1219477728009224
+0.13543231785297394 0.002637560246512294 0.0024218475446105003 0.028478823602199554 0.1747322827577591
+0.11290988326072693 0.0009178576874546707 0.0011191165540367365 0.025282414630055428 0.140602245926857
+0.13313066959381104 0.00653867470100522 0.005301614757627249 0.04661089554429054 0.18075987696647644
+0.11186628043651581 0.0031515685841441154 0.0030068671330809593 0.04739854857325554 0.14384105801582336
+0.12547793984413147 0.000586793408729136 0.001883206539787352 0.042086753994226456 0.1593744307756424
+0.1311400830745697 0.004582538269460201 0.0075226835906505585 0.036332253366708755 0.14750416576862335
+0.10674192011356354 0.0016178505029529333 0.00404041912406683 0.02969054877758026 0.15547111630439758
+0.12054100632667542 0.0015262191882357001 0.0012436799006536603 0.03586965054273605 0.16557340323925018
+0.10085606575012207 0.005815595388412476 0.0053849853575229645 0.052253179252147675 0.09954187273979187
+0.14354564249515533 0.0015698226634413004 0.0017467461293563247 0.03430379554629326 0.19585657119750977
+0.09884459525346756 0.0009991753613576293 0.0013707675971090794 0.030959030613303185 0.11731235682964325
+0.1075257658958435 0.0029456184711307287 0.0017386868130415678 0.026188598945736885 0.15525157749652863
+0.10613490641117096 0.0012024311581626534 0.001616329071111977 0.05164124071598053 0.16181136667728424
+0.11515817046165466 0.0003617640759330243 0.003216330660507083 0.05770521238446236 0.10862671583890915
+0.15348860621452332 0.0002747188264038414 0.0011622768361121416 0.029282346367836 0.2078692764043808
+0.1269940882921219 0.00026798484032042325 0.0006606073002330959 0.021913370117545128 0.19353927671909332
+0.11326038837432861 0.00035354914143681526 0.00071821449091658 0.06912095099687576 0.15105511248111725
+0.11108041554689407 0.0004776871937792748 0.0009482577443122864 0.04772288724780083 0.12474008649587631
+0.1119060218334198 0.00021017660037614405 0.0010412660194560885 0.04867212474346161 0.17207087576389313
+0.10149254649877548 0.0004204070719424635 0.0015770528698340058 0.07548791170120239 0.14144787192344666
+0.12476525455713272 0.00039558979915454984 0.0020148754119873047 0.037099774926900864 0.16490508615970612
+0.10951396077871323 0.00023137028620112687 0.0010337928542867303 0.03768595680594444 0.09865929186344147
+0.12911929190158844 0.0009116881992667913 0.001075586536899209 0.04736025631427765 0.12431634217500687
+0.10855914652347565 0.0007730926154181361 0.0010784454643726349 0.03900396078824997 0.14088697731494904
+0.10615470260381699 0.00020529043104033917 0.0018964045448228717 0.03334664925932884 0.1684587150812149
+0.10957199335098267 0.0003788339381571859 0.0016961663495749235 0.08517827093601227 0.13578291237354279
+0.10957460105419159 0.0010083633242174983 0.0013775493716821074 0.04027320444583893 0.1425473690032959
+0.10214971005916595 0.0008399036014452577 0.0013365609338507056 0.04594000056385994 0.07668431103229523
+0.11814653873443604 0.00020509996102191508 0.000812402053270489 0.030010420829057693 0.15767212212085724
+0.11403584480285645 0.0008288199896924198 0.0015265692491084337 0.02858065813779831 0.17847180366516113
+0.10857801139354706 0.0013242157874628901 0.0022657655645161867 0.03730422630906105 0.15767815709114075
+0.11253337562084198 0.00030409140163101256 0.001987309427931905 0.044857874512672424 0.17467781901359558
+0.12019580602645874 0.0010752237867563963 0.0013197859516367316 0.04706418514251709 0.19727729260921478
+0.12559916079044342 0.0018028623890131712 0.0026824590750038624 0.036993108689785004 0.11937374621629715
+0.1167660728096962 0.0003197039768565446 0.0017657761927694082 0.025641435757279396 0.14886574447155
+0.13719719648361206 0.0017136543756350875 0.0011198262218385935 0.02750292979180813 0.2087530642747879
+0.1024094894528389 0.0038352857809513807 0.0032591896597296 0.039626941084861755 0.13601915538311005
+0.12364274263381958 0.0008522769785486162 0.0023800120688974857 0.049578167498111725 0.16596630215644836
+0.1257864236831665 0.0014268838567659259 0.0014282468473538756 0.041781604290008545 0.12372508645057678
+0.10942468792200089 0.004188756458461285 0.0024827923625707626 0.02930864877998829 0.14707310497760773
+0.10224217176437378 0.0014886683784425259 0.0017676933202892542 0.0475410595536232 0.13536618649959564
+0.10145071148872375 0.0005976983811706305 0.002810828620567918 0.04709309712052345 0.09659472107887268
+0.11411568522453308 0.0014139065751805902 0.002402317011728883 0.03716396167874336 0.130027636885643
+0.11794520914554596 0.0006508917431347072 0.001077932771295309 0.03311561420559883 0.1557416319847107
+0.11492633819580078 0.0002722400240600109 0.0016003375640138984 0.028815699741244316 0.15440857410430908
+0.10478575527667999 0.0003970287798438221 0.002020888961851597 0.038286611437797546 0.13149823248386383
+0.11586112529039383 0.0002911539631895721 0.0020099764224141836 0.05714696645736694 0.09515228867530823
+0.13067197799682617 0.0002657449513208121 0.0011807093396782875 0.03784346580505371 0.1811595857143402
+0.12224064767360687 0.0004796934954356402 0.0008637876017019153 0.027757365256547928 0.1834239512681961
+0.11324694752693176 0.0010244937147945166 0.001070819329470396 0.03839315474033356 0.16376638412475586
+0.13070684671401978 0.0007120308582670987 0.001121465815231204 0.07483583688735962 0.14698997139930725
+0.10597370564937592 0.0002795595792122185 0.0008494649082422256 0.04287683591246605 0.11195569485425949
+0.11918200552463531 0.0009653487941250205 0.000950838322751224 0.03525274246931076 0.09537339210510254
+0.0980580747127533 0.0005103547009639442 0.0008147038170136511 0.03545659780502319 0.13073541224002838
+0.09691216796636581 0.00042761501390486956 0.003194018267095089 0.052561771124601364 0.1480124443769455
+0.10344285517930984 0.00045725045492872596 0.005159473046660423 0.055852390825748444 0.14524811506271362
+0.12197960913181305 0.0005955513333901763 0.0020245187915861607 0.04473366588354111 0.1172061413526535
+0.11709921061992645 0.002591583179309964 0.0014966463204473257 0.028502753004431725 0.09297094494104385
+0.12331034988164902 0.0013398358132690191 0.0018244730308651924 0.028042178601026535 0.16809850931167603
+0.12110099196434021 0.0006836940883658826 0.0011703402269631624 0.050537701696157455 0.18018838763237
+0.1185784637928009 0.00314504187554121 0.0034345658496022224 0.07116778939962387 0.1294366717338562
+0.10409620404243469 0.0009630229906179011 0.0013286847388371825 0.07061317563056946 0.12211479991674423
+0.11291255056858063 0.0005508934264071286 0.004570250399410725 0.027305930852890015 0.15463416278362274
+0.09770708531141281 0.0007268024492077529 0.004974302835762501 0.029654256999492645 0.1018998846411705
+0.12391814589500427 0.0005538996192626655 0.001310753053985536 0.03571369871497154 0.1627233326435089
+0.10764993727207184 0.002441591117531061 0.0014631652738898993 0.04294077679514885 0.1029309406876564
+0.1194029301404953 0.0011814030585810542 0.0013110917061567307 0.03356458991765976 0.17993302643299103
+0.1102810800075531 0.00040312603232450783 0.0007374563138000667 0.02822612226009369 0.14490589499473572
+0.10204020142555237 0.0017152277287095785 0.000981003511697054 0.03203611448407173 0.1451176255941391
+0.1218530535697937 0.0013585565611720085 0.0010836673900485039 0.03818278759717941 0.1813705861568451
+0.10880928486585617 0.00041530458838678896 0.0018305579433217645 0.043502327054739 0.154218852519989
+0.11966609954833984 0.00032361707417294383 0.0019451123662292957 0.04198214411735535 0.18246641755104065
+0.10710183531045914 0.00021476020629052073 0.0005840593948960304 0.033426038920879364 0.12698189914226532
+0.10981988906860352 0.00020448298892006278 0.0006881825975142419 0.028032949194312096 0.11601642519235611
+0.1103149950504303 0.00022994281607680023 0.0007089483551681042 0.03029257245361805 0.1398629993200302
+0.10141683369874954 0.000305720983305946 0.0007305560284294188 0.03146814554929733 0.1287045180797577
+0.12044085562229156 0.0003881728043779731 0.0008299690089188516 0.032414425164461136 0.22017289698123932
+0.1045217216014862 0.00019309586787130684 0.0006877144332975149 0.03640468791127205 0.15162919461727142
+0.08604985475540161 0.00024208684044424444 0.000761717208661139 0.03411022201180458 0.09621509909629822
+0.11790281534194946 0.00029727432411164045 0.0009347453014925122 0.029377879574894905 0.15310192108154297
+0.11004585027694702 0.00021324947010725737 0.0008191733504645526 0.038204729557037354 0.12154029309749603
+0.10281984508037567 0.00043778924737125635 0.0011949347099289298 0.03467091917991638 0.10650143772363663
+0.08384823054075241 0.0002512831997592002 0.0010803586337715387 0.0368884839117527 0.07676846534013748
+0.11031252145767212 0.00033895840169861913 0.0008657966391183436 0.03753173351287842 0.14950057864189148
+0.1161646842956543 0.0005591866793110967 0.0010785210179165006 0.03389466926455498 0.17737893760204315
+0.09142018854618073 0.000300271698506549 0.0012940175365656614 0.04832851514220238 0.11924754828214645
+0.09553378075361252 0.0001639670954318717 0.00118943490087986 0.05139697715640068 0.09484703838825226
+0.10888975858688354 0.00015550010721199214 0.0021152349654585123 0.029700133949518204 0.1513689160346985
+0.11112384498119354 0.0002610501542221755 0.0010357669088989496 0.027354469522833824 0.12715399265289307
+0.11009855568408966 0.0005890494794584811 0.0009337993687950075 0.026404669508337975 0.15567505359649658
+0.11934874206781387 0.0008360219071619213 0.0008900373941287398 0.04077369347214699 0.1772540956735611
+0.10492302477359772 0.00023860821966081858 0.0009618665208108723 0.048568688333034515 0.15465407073497772
+0.10323141515254974 0.0004871810378972441 0.001311683445237577 0.032347243279218674 0.11685331165790558
+0.09401391446590424 0.0007407785742543638 0.001438300940208137 0.04962582886219025 0.11180052161216736
+0.11167173087596893 0.0001562442339491099 0.0006641809595748782 0.030937349423766136 0.12283909320831299
+0.11005966365337372 0.0011252503609284759 0.002755633555352688 0.0345655158162117 0.13821923732757568
+0.10621383786201477 0.0007006920059211552 0.0035603351425379515 0.03991541638970375 0.15833377838134766
+0.10510173439979553 0.000633567338809371 0.0011971441563218832 0.03684312105178833 0.14638271927833557
+0.09743218123912811 0.0026872064918279648 0.0012028476921841502 0.031248142942786217 0.12891055643558502
+0.10960255563259125 0.0017318099271506071 0.0015215467428788543 0.03141726180911064 0.15946608781814575
+0.10550622642040253 0.0001952113670995459 0.0012931692181155086 0.03774157539010048 0.1660676747560501
+0.11006100475788116 0.0014041662216186523 0.00221132580190897 0.052212659269571304 0.11486080288887024
+0.11092814058065414 0.0012155334698036313 0.0011117247631773353 0.04370787367224693 0.12242009490728378
+0.1045260950922966 0.00031998861231841147 0.0016584174009039998 0.029007237404584885 0.12930583953857422
+0.126299649477005 0.00037060389877296984 0.0016179345548152924 0.05358980596065521 0.1491970419883728
+0.1032635048031807 0.00024762333487160504 0.0013507072580978274 0.04432782158255577 0.09856691211462021
+0.10482611507177353 0.00029289210215210915 0.0012674693716689944 0.05116024985909462 0.14414510130882263
+0.09848213195800781 0.0003217233170289546 0.0011024177074432373 0.03240125626325607 0.13178318738937378
+0.09517945349216461 0.0002547514741308987 0.0008415683987550437 0.027610542252659798 0.10507824271917343
+0.09943369030952454 0.0003002499870490283 0.0005491306656040251 0.029562951996922493 0.09344103187322617
+0.092983677983284 0.0004791012906935066 0.001079542562365532 0.024867556989192963 0.09660609066486359
+0.12245164811611176 0.0003238012141082436 0.0016845668433234096 0.03385559841990471 0.1716345101594925
+0.12483556568622589 0.0005286639207042754 0.001296633156016469 0.04453186318278313 0.19762656092643738
+0.10063754767179489 0.0015648265834897757 0.0010695065138861537 0.04211843013763428 0.1333891898393631
+0.10859093070030212 0.0011958113173022866 0.001539379358291626 0.05577763915061951 0.1285044252872467
+0.10835118591785431 0.00023486714053433388 0.0006059848237782717 0.028888070955872536 0.1081910952925682
+0.10741351544857025 0.001967746764421463 0.0007205419824458659 0.03119128756225109 0.11273176223039627
+0.11112233996391296 0.0025000933092087507 0.0012761485995724797 0.03126412257552147 0.16352908313274384
+0.10223430395126343 0.00038473683525808156 0.0007837212760932744 0.04490305110812187 0.1480112075805664
+0.10001649707555771 0.0010689824121072888 0.0015264545800164342 0.04917272925376892 0.1361626535654068
+0.09140922129154205 0.001806767308153212 0.0019366167252883315 0.03501204773783684 0.07008746266365051
+0.10756395012140274 0.00025036834995262325 0.0006327240262180567 0.024859409779310226 0.12525349855422974
+0.07466943562030792 0.001237477408722043 0.002252717036753893 0.028041711077094078 0.10164561122655869
+0.09924936294555664 0.0016823504120111465 0.0034169612918049097 0.04185336083173752 0.14537972211837769
+0.09398390352725983 0.00018236885080114007 0.0010580177186056972 0.040861792862415314 0.10515151172876358
+0.11572305113077164 0.0022304926533252 0.0015191843267530203 0.044924236834049225 0.16141481697559357
+0.11278637498617172 0.001998340943828225 0.0022137558553367853 0.030588053166866302 0.13899829983711243
+0.11317034065723419 0.00026655220426619053 0.0007366887293756008 0.04109374061226845 0.13159331679344177
+0.10881171375513077 0.002173895016312599 0.001481909304857254 0.030244380235671997 0.13530349731445312
+0.08872281759977341 0.0023666508495807648 0.001756776007823646 0.0467107817530632 0.10673124343156815
+0.10228052735328674 0.0003353731008246541 0.001018948620185256 0.033812567591667175 0.1502925455570221
+0.1017688438296318 0.0010927659459412098 0.0012662316439673305 0.03878629580140114 0.13877850770950317
+0.10976216197013855 0.0012306071585044265 0.0012972912518307567 0.04508034512400627 0.15292538702487946
+0.11977830529212952 0.00031536800088360906 0.0007438150933012366 0.04392319545149803 0.1399867832660675
+0.09587673097848892 0.0004046020039822906 0.000966921157669276 0.02799428254365921 0.10376502573490143
+0.10143300890922546 0.0007396508590318263 0.00070878432597965 0.02417858876287937 0.09503374993801117
+0.11096621304750443 0.0005258918390609324 0.000819226261228323 0.025421518832445145 0.19927546381950378
+0.10281822085380554 0.00017030125309247524 0.001609215629287064 0.04552779719233513 0.15318836271762848
+0.10691545158624649 0.0006516650319099426 0.0009605712257325649 0.0457150861620903 0.13658978044986725
+0.08846404403448105 0.0015323528787121177 0.0019603471737354994 0.029948383569717407 0.09577172994613647
+0.1069435253739357 0.00032042371458373964 0.0017400466604158282 0.025331147015094757 0.10894539207220078
+0.09968072175979614 0.0015881865983828902 0.0007414935389533639 0.02176019363105297 0.15111102163791656
+0.11367671191692352 0.0037881971802562475 0.003275121795013547 0.06970618665218353 0.1458045244216919
+0.07981876283884048 0.0008402016828767955 0.0027665868401527405 0.051426488906145096 0.09073624014854431
+0.09162194281816483 0.0019118241034448147 0.0016918256878852844 0.0769125372171402 0.1030314639210701
+0.10163882374763489 0.0049762362614274025 0.003491911105811596 0.02712419629096985 0.1325882077217102
+0.10039964318275452 0.0009559233440086246 0.001136911567300558 0.04310457408428192 0.08915481716394424
+0.095282182097435 0.0015341817634180188 0.002979973563924432 0.03854567930102348 0.08604421466588974
+0.10580748319625854 0.0030881166458129883 0.004384869243949652 0.03702074661850929 0.1789560616016388
+0.08199173957109451 0.0005136067629791796 0.0012914189137518406 0.04212390258908272 0.08571464568376541
+0.08054463565349579 0.001397405401803553 0.0026707169599831104 0.03709276393055916 0.10707800090312958
+0.09974852204322815 0.0021481907460838556 0.0042250449769198895 0.03974200412631035 0.1304505467414856
+0.10156272351741791 0.00029387083486653864 0.001656249281950295 0.03385251760482788 0.12751413881778717
+0.09122313559055328 0.0019311244832351804 0.000516196305397898 0.04875869303941727 0.08627945184707642
+0.08963271230459213 0.003253704635426402 0.0009448531200177968 0.025985639542341232 0.11554328352212906
+0.1029217317700386 0.0009734776103869081 0.0008383903186768293 0.0344439372420311 0.1607079654932022
+0.10124145448207855 0.0003384935262147337 0.0006995961302891374 0.03572401776909828 0.14414025843143463
+0.09036736190319061 0.0019903378561139107 0.000955000170506537 0.041499681770801544 0.07043104618787766
+0.1195528656244278 0.0016623461851850152 0.0008189482032321393 0.030275622382760048 0.13404324650764465
+0.10071240365505219 0.00018738843209575862 0.0010549926664680243 0.028797224164009094 0.13018551468849182
+0.10168494284152985 0.0012010299833491445 0.0036500052083283663 0.030883412808179855 0.16966043412685394
+0.12706458568572998 0.0008832316962070763 0.003563635516911745 0.04779399186372757 0.21367520093917847
+0.1103782206773758 0.00039605345227755606 0.0013459170004352927 0.06295867264270782 0.1253819316625595
+0.12706568837165833 0.002544965362176299 0.002372842514887452 0.03137405961751938 0.11613298952579498
+0.08754048496484756 0.0012944495538249612 0.0030080610886216164 0.029188502579927444 0.09088049083948135
+0.116855189204216 0.0008357411716133356 0.001257662195712328 0.0379667691886425 0.1407937854528427
+0.09558067470788956 0.0041416604071855545 0.004449216183274984 0.03890078887343407 0.14081130921840668
+0.10402214527130127 0.0017238148720934987 0.002857081824913621 0.04547635838389397 0.13423432409763336
+0.08020482957363129 0.0007510441355407238 0.001765471650287509 0.030496124178171158 0.08016155660152435
+0.09806445240974426 0.0032214983366429806 0.0037834420800209045 0.030205825343728065 0.11466778814792633
+0.09097572416067123 0.0010715252719819546 0.0017068168381229043 0.03146917372941971 0.06705863773822784
+0.08832299709320068 0.0006199987255968153 0.001214285148307681 0.0561896488070488 0.10175155103206635
+0.10213720053434372 0.002071957103908062 0.001535085029900074 0.05681513622403145 0.13994979858398438
+0.12316611409187317 0.0008014798513613641 0.000727717881090939 0.031043436378240585 0.18985655903816223
+0.10650435835123062 0.0002463231794536114 0.0011782936053350568 0.03194788470864296 0.1529081016778946
+0.09624426066875458 0.0010004685027524829 0.002223980613052845 0.05785129964351654 0.07278762012720108
+0.09425754845142365 0.0006151435663923621 0.0025528701953589916 0.06319211423397064 0.08086226880550385
+0.08978527784347534 0.0004794493142981082 0.0006677558412775397 0.04428291320800781 0.09199833124876022
+0.09369789808988571 0.0029773362912237644 0.0020818111952394247 0.027186302468180656 0.14733242988586426
+0.08959519863128662 0.0020733431447297335 0.003970824182033539 0.08455855399370193 0.12505845725536346
+0.0901409238576889 0.0004047600377816707 0.0015462833689525723 0.04714033007621765 0.12665244936943054
+0.10695627331733704 0.00431117694824934 0.0009894213872030377 0.04626995325088501 0.1011892557144165
+0.11069497466087341 0.004014123231172562 0.0010892817517742515 0.022023051977157593 0.15521319210529327
+0.10433126240968704 0.0002450946776662022 0.0018594216089695692 0.025943081825971603 0.1619340032339096
+0.08893271535634995 0.002123628742992878 0.006347445771098137 0.05421002581715584 0.085673026740551
+0.11602874845266342 0.001099743414670229 0.0025113169103860855 0.03476881608366966 0.1526108980178833
+0.06949879974126816 0.0004873168363701552 0.0015976387076079845 0.025470905005931854 0.06895586103200912
+0.10961258411407471 0.0015573996352031827 0.003986796364188194 0.023554442450404167 0.1700069010257721
+0.08939700573682785 0.00023448745196219534 0.0019575341138988733 0.05017991364002228 0.08962196111679077
+0.10687348991632462 0.0012028561905026436 0.0007425333024002612 0.03638176992535591 0.15492156147956848
+0.08169551193714142 0.001934935455210507 0.0008139644633047283 0.034638646990060806 0.09402463585138321
+0.10951404273509979 0.0005967551260255277 0.0008406259585171938 0.029893934726715088 0.15163902938365936
+0.09679817408323288 0.00018597014422994107 0.0011790709104388952 0.02274690754711628 0.13254359364509583
+0.09637731313705444 0.00022669871395919472 0.0007378943264484406 0.029052965342998505 0.13150031864643097
+0.07914434373378754 0.00023519426758866757 0.0013016501907259226 0.04428168013691902 0.11154884099960327
+0.08745388686656952 0.00033620031899772584 0.0010464225197210908 0.03487718850374222 0.10952858626842499
+0.09569522738456726 0.00044690020149573684 0.0007499100174754858 0.04039788991212845 0.10522744059562683
+0.08895575255155563 0.00027925631729885936 0.0008251424878835678 0.04850846156477928 0.11016546189785004
+0.09193447232246399 0.0002479729300830513 0.0017892372561618686 0.02912210114300251 0.1222735196352005
+0.09874016046524048 0.0003502621257212013 0.002232041908428073 0.04642904922366142 0.09198909997940063
+0.09630513191223145 0.00039418714004568756 0.0012235213071107864 0.04290498048067093 0.0992986336350441
+0.09834232926368713 0.0008524467120878398 0.0016533213201910257 0.028589198365807533 0.155478373169899
+0.10874094814062119 0.00029601508867926896 0.0010725012980401516 0.030290400609374046 0.15826940536499023
+0.07509232312440872 0.0005604171892628074 0.0007564894040115178 0.03556598350405693 0.07428443431854248
+0.11295285820960999 0.0010340007720515132 0.0006627255352213979 0.03505526855587959 0.13367612659931183
+0.08715619146823883 0.00032038206700235605 0.0006453784881159663 0.02821427583694458 0.10928656905889511
+0.09881839901208878 0.00015796047227922827 0.0008169145439751446 0.02544526197016239 0.1564977616071701
+0.08021815121173859 0.00016356489504687488 0.0006503481417894363 0.04027537256479263 0.0819111317396164
+0.08812540769577026 0.00018362300761509687 0.0011688278755173087 0.04044526070356369 0.12086997181177139
+0.09948715567588806 0.00043117156019434333 0.0007865724037401378 0.033298593014478683 0.11719661206007004
+0.07892028242349625 0.0006352189229801297 0.0009157155873253942 0.029861165210604668 0.1001756563782692
+0.08286169171333313 0.00029220496071502566 0.0005865877028554678 0.023056652396917343 0.10044103115797043
+0.08354198932647705 0.00028726604068651795 0.0011548068141564727 0.03564765676856041 0.07722929120063782
+0.08469598740339279 0.00027637428138405085 0.0013954840833321214 0.041986290365457535 0.09259027987718582
+0.08599137514829636 0.00017726782243698835 0.0008512369822710752 0.03408735245466232 0.09918688982725143
+0.07239478826522827 0.00022908371465746313 0.0017167042242363095 0.03811051324009895 0.08759752660989761
+0.09567559510469437 0.00020929655875079334 0.0012707008281722665 0.03060976415872574 0.12809808552265167
+0.08566644042730331 0.0007615580689162016 0.0009239249629899859 0.04004717245697975 0.10224906355142593
+0.0903460830450058 0.0007762927561998367 0.001002900185994804 0.03457527235150337 0.08134360611438751
+0.10965295135974884 0.00019006991351488978 0.0005938825779594481 0.033322133123874664 0.08201314508914948
+0.10059680044651031 0.0002873575431294739 0.000530207937117666 0.021551117300987244 0.13866741955280304
+0.10416512191295624 0.00024147499061655253 0.0009293883922509849 0.028597401455044746 0.18022169172763824
+0.09403924643993378 0.0002392615278949961 0.0022717381361871958 0.050063084810972214 0.1420835703611374
+0.0923616886138916 0.0004946911940351129 0.0009074854315258563 0.04312313720583916 0.07940249890089035
+0.10879568755626678 0.00097006990108639 0.0012145937653258443 0.028895413503050804 0.10345307737588882
+0.10646425932645798 0.00019504041119944304 0.0009870209032669663 0.05522657558321953 0.08625204861164093
+0.07656528055667877 0.0013115155743435025 0.0008426827844232321 0.026075655594468117 0.08788219094276428
+0.1054687574505806 0.0018282223027199507 0.002038444159552455 0.045953765511512756 0.15444770455360413
+0.09958682954311371 0.00025546562392264605 0.0008404725813306868 0.048557206988334656 0.1583390235900879
+0.0917174220085144 0.0016605452401563525 0.002445963444188237 0.038980282843112946 0.10155904293060303
+0.11024467647075653 0.0019404650665819645 0.003919287584722042 0.02876121550798416 0.09784841537475586
+0.08587439358234406 0.00018034831737168133 0.0010825073113664985 0.024124760180711746 0.09364453703165054
+0.08517232537269592 0.0033770031295716763 0.0014338343171402812 0.029216011986136436 0.08363182097673416
+0.1057291328907013 0.003356009256094694 0.0029237442649900913 0.038410790264606476 0.17014078795909882
+0.09412002563476562 0.00031751839560456574 0.0016618625959381461 0.04256947338581085 0.1482241302728653
+0.09348352253437042 0.0029219016432762146 0.0018733268370851874 0.04090665280818939 0.11709461361169815
+0.07979937642812729 0.004084685817360878 0.0026405637618154287 0.031207380816340446 0.06697547435760498
+0.09046308696269989 0.0004608709423337132 0.0007991954917088151 0.026349831372499466 0.1256588250398636
+0.0972021073102951 0.0023101256228983402 0.0028644604608416557 0.0383094847202301 0.13176211714744568
+0.10550437867641449 0.0033689141273498535 0.004222358111292124 0.038130082190036774 0.14725267887115479
+0.12013357132673264 0.000323854765156284 0.001103202230297029 0.029767384752631187 0.10196656733751297
+0.07187782227993011 0.0023624028544873 0.0019167778082191944 0.07469328492879868 0.08400025963783264
+0.10228456556797028 0.0023038468789309263 0.0018923967145383358 0.024867450818419456 0.15020515024662018
+0.10877329111099243 0.0002351808361709118 0.001217326964251697 0.039080969989299774 0.14665696024894714
+0.08497969806194305 0.0012308493023738265 0.002624996704980731 0.0541885644197464 0.09153681993484497
+0.0763537585735321 0.0005027882289141417 0.0006354826036840677 0.024092495441436768 0.06352734565734863
+0.10723946988582611 0.00016504299128428102 0.0016508599510416389 0.01805766485631466 0.11038805544376373
+0.07933950424194336 0.00015789404278621078 0.0008040095563046634 0.02256563864648342 0.1236032247543335
+0.07982758432626724 0.0004894845187664032 0.0015618171310052276 0.05357581749558449 0.08823773264884949
+0.09277497977018356 0.00020512753690127283 0.0011625683400779963 0.06235421076416969 0.11920454353094101
+0.09548138827085495 0.000848229683469981 0.0016006098594516516 0.028421513736248016 0.12065690010786057
+0.10455580055713654 0.0006449244683608413 0.0023622941225767136 0.03356155380606651 0.1164180189371109
+0.1041092574596405 0.0006003899616189301 0.0004409722751006484 0.026826729997992516 0.1762809157371521
+0.08835020661354065 0.003051575506106019 0.003711810801178217 0.05737682431936264 0.120809406042099
+0.09837789088487625 0.0007152897887863219 0.0018569175153970718 0.07751736044883728 0.12232330441474915
+0.08103932440280914 0.0017598316771909595 0.0015354595379903913 0.030767980962991714 0.07340533286333084
+0.1086016446352005 0.0033462534192949533 0.0037906989455223083 0.06533266603946686 0.11647640913724899
+0.09342604875564575 0.00018300920783076435 0.0004993138718418777 0.022084485739469528 0.1195632740855217
+0.10512173175811768 0.003501397091895342 0.0066292076371610165 0.045685578137636185 0.14773070812225342
+0.0796375721693039 0.002038680948317051 0.0060309614054858685 0.08590672165155411 0.09594284743070602
+0.0898948535323143 0.0015950746601447463 0.0011044357670471072 0.03134769946336746 0.13353674113750458
+0.11232394725084305 0.005811282899230719 0.006206993944942951 0.027571745216846466 0.14645671844482422
+0.10401728749275208 0.0008550502243451774 0.0020508072338998318 0.04001188650727272 0.13256047666072845
+0.10087066888809204 0.002955770818516612 0.001472729374654591 0.05106223374605179 0.08393829315900803
+0.0703766718506813 0.005412465892732143 0.0019878221210092306 0.03469749167561531 0.07383004575967789
+0.11791357398033142 0.0010306037729606032 0.0004548674332909286 0.03255144879221916 0.18655171990394592
+0.08224983513355255 0.0005946039455011487 0.0012699550716206431 0.03284161165356636 0.10589083284139633
+0.1134299635887146 0.0015990115934982896 0.0009956626454368234 0.05811199173331261 0.0823121964931488
+0.08080317825078964 0.0005888648447580636 0.0006183492951095104 0.04335543513298035 0.0642472431063652
+0.10057690739631653 0.0002657205332070589 0.0011710774851962924 0.02268989384174347 0.1206805408000946
+0.1146896481513977 0.0005893624038435519 0.0020881479140371084 0.020504998043179512 0.21853527426719666
+0.09401264786720276 0.0003757443919312209 0.0026932235341519117 0.041335802525281906 0.1182708814740181
+0.10397186130285263 0.000677136704325676 0.0008432964095845819 0.04168597236275673 0.11698281019926071
+0.08966649323701859 0.0024781806860119104 0.002460357965901494 0.027256889268755913 0.059322379529476166
+0.09323028475046158 0.0008105931337922812 0.0026706827338784933 0.03190527483820915 0.08456751704216003
+0.11416388303041458 0.0014974892837926745 0.0005839979858137667 0.025888515636324883 0.17505718767642975
+0.07740035653114319 0.005415296647697687 0.003616867121309042 0.0438741110265255 0.07817844301462173
+0.10479240119457245 0.0014083473943173885 0.002728450810536742 0.05721269175410271 0.10410120338201523
+0.07468093186616898 0.0018018570262938738 0.010845191776752472 0.05086857080459595 0.0808405950665474
+0.09814280271530151 0.004520772956311703 0.007972877472639084 0.026645829901099205 0.12065857648849487
+0.0795656219124794 0.00035225271130912006 0.0019204311538487673 0.02689412608742714 0.06670945137739182
+0.09782016277313232 0.0031224095728248358 0.002140006050467491 0.035810187458992004 0.08084210008382797
+0.09380412846803665 0.004439228679984808 0.0032326991204172373 0.030100490897893906 0.09230118244886398
+0.08866830915212631 0.0004460815398488194 0.0008746830862946808 0.02354363724589348 0.14144566655158997
+0.09665103256702423 0.0017203708412125707 0.0010633382480591536 0.02992466278374195 0.15556804835796356
+0.09325292706489563 0.0029582607094198465 0.0012874709209427238 0.048724617809057236 0.09893109649419785
+0.08571960031986237 0.00101007882039994 0.0007104739779606462 0.03286002576351166 0.0840643048286438
+0.08439229428768158 0.0003074543783441186 0.0007020363118499517 0.01967649906873703 0.11756003648042679
+0.09841176867485046 0.0018603865755721927 0.0016199905658140779 0.01909666694700718 0.13225117325782776
+0.07561248540878296 0.0015397819224745035 0.0020111682824790478 0.024705206975340843 0.1086534857749939
+0.060958050191402435 0.0002505010925233364 0.0008602471207268536 0.03204150125384331 0.07446637749671936
+0.08533083647489548 0.001957578118890524 0.0009652624721638858 0.03765232861042023 0.08378048986196518
+0.09846177697181702 0.0022948002442717552 0.0014956614468246698 0.049194272607564926 0.13581442832946777
+0.09510177373886108 0.00019813994003925472 0.0004579798842314631 0.026114238426089287 0.12515833973884583
+0.08823849260807037 0.0017588550690561533 0.0013722333824262023 0.025824883952736855 0.12347039580345154
+0.08343042433261871 0.0023034189362078905 0.0018257738556712866 0.04196362569928169 0.0793832466006279
+0.09025120735168457 0.00024612632114440203 0.0006078880396671593 0.041104476898908615 0.10357506573200226
+0.10083647072315216 0.0015083777252584696 0.001969657838344574 0.030102290213108063 0.16027170419692993
+0.08195388317108154 0.001509008463472128 0.0020386886317282915 0.03274159878492355 0.090310238301754
+0.07917390763759613 0.00020540542027447373 0.0007936760666780174 0.03850635513663292 0.08661060780286789
+0.09344744682312012 0.0009478902793489397 0.0009541717008687556 0.03462633863091469 0.1393549144268036
+0.08861677348613739 0.001221851329319179 0.0006426103063859046 0.02793176658451557 0.12157224863767624
+0.11165839433670044 0.00043450103839859366 0.000607604393735528 0.02626793459057808 0.13583338260650635
+0.08020932972431183 0.00028082693461328745 0.001279745134525001 0.03274228423833847 0.10273712873458862
+0.08742819726467133 0.0006548790843226016 0.0013849990209564567 0.031039847061038017 0.09631486982107162
+0.07895830273628235 0.0001924546086229384 0.0007824549684301019 0.026936635375022888 0.10400553047657013
+0.09198042005300522 0.0007054791785776615 0.0007478863117285073 0.027905425056815147 0.1099679172039032
+0.09648257493972778 0.0011757060419768095 0.001768536982126534 0.031190969049930573 0.10853821784257889
+0.09913863241672516 0.00029296710272319615 0.001460059778764844 0.029455451294779778 0.14615581929683685
+0.0892489030957222 0.0008485369617119431 0.000648481713142246 0.03254375979304314 0.14983142912387848
+0.09934019297361374 0.0020627016201615334 0.0006532272091135383 0.03901679813861847 0.11896288394927979
+0.09311951696872711 0.0010045700473710895 0.0005690979887731373 0.03110061213374138 0.11107032746076584
+0.08768576383590698 0.00018744796398095787 0.000817073683720082 0.03372126445174217 0.12216696888208389
+0.08876045048236847 0.0009638220653869212 0.0010362830944359303 0.031581878662109375 0.06823853403329849
+0.09442683309316635 0.000782468356192112 0.0006973561830818653 0.02446451596915722 0.1420847475528717
+0.07569936662912369 0.00023321252956520766 0.0007043253281153738 0.0432337261736393 0.10241254419088364
+0.10885760188102722 0.00032140163239091635 0.0011876539792865515 0.03590792044997215 0.17834463715553284
+0.0896429568529129 0.00045463963760994375 0.0013198753586038947 0.035679738968610764 0.10701773315668106
+0.07885195314884186 0.00021878609550185502 0.001177554135210812 0.028599578887224197 0.05804191529750824
+0.08987803757190704 0.00034253523335792124 0.0006717359065078199 0.020341312512755394 0.12579482793807983
+0.0771079733967781 0.0012962588807567954 0.0008560936548747122 0.025157935917377472 0.07572028040885925
+0.07465756684541702 0.0010333660757169127 0.0017970075132325292 0.03376322239637375 0.0907989889383316
+0.07985249906778336 0.00024570946698077023 0.0008123746374621987 0.03707048296928406 0.09489756077528
+0.09394863992929459 0.002146859187632799 0.0011544438311830163 0.03135540336370468 0.12604156136512756
+0.08393536508083344 0.0016017263988032937 0.001117926207371056 0.03491935133934021 0.0905294418334961
+0.07618613541126251 0.0001711091899778694 0.0007232649368233979 0.0235028937458992 0.09257260710000992
+0.0813896507024765 0.0014910617610439658 0.0018601005431264639 0.03266368433833122 0.10256680846214294
+0.08561338484287262 0.001119401422329247 0.0015888301422819495 0.03657594695687294 0.121432825922966
+0.07362050563097 0.00018002887372858822 0.0009611800778657198 0.04312485083937645 0.09214401990175247
+0.08336085081100464 0.0013016033917665482 0.0024540580343455076 0.0325419120490551 0.09696227312088013
+0.08382795751094818 0.000611022871453315 0.001926170545630157 0.033243950456380844 0.10198140144348145
+0.10321718454360962 0.00047218569670803845 0.0008726944797672331 0.03127329796552658 0.11255019903182983
+0.07506414502859116 0.0018542687175795436 0.0010714124655351043 0.03820423036813736 0.06826092302799225
+0.09521718323230743 0.0010146613931283355 0.0009823148138821125 0.033485498279333115 0.13903871178627014
+0.08458784222602844 0.0001919299829751253 0.0007632699562236667 0.04323764890432358 0.12234213948249817
+0.09528951346874237 0.001085490919649601 0.0010496712056919932 0.03861532732844353 0.1276937872171402
+0.0914255678653717 0.0010630055330693722 0.0009824777953326702 0.03802118077874184 0.10801830887794495
+0.0804627537727356 0.00021967645443510264 0.0006897866260260344 0.03470808267593384 0.09063267707824707
+0.0827709436416626 0.0005568171618506312 0.0007013550493866205 0.025035066530108452 0.11421401053667068
+0.067377470433712 0.0013002584455534816 0.0014151303330436349 0.023293636739253998 0.04757274314761162
+0.09477125853300095 0.0004053551238030195 0.0010458604665473104 0.02714894711971283 0.12309405952692032
+0.08487064391374588 0.00037318223621696234 0.0006244191899895668 0.031130142509937286 0.1082233190536499
+0.08045449107885361 0.0013590243179351091 0.0007788900402374566 0.03280644118785858 0.11715535819530487
+0.08898159861564636 0.000967163301538676 0.0007242302526719868 0.0588105171918869 0.07908795028924942
+0.08269675076007843 0.00013770612713415176 0.0006534041604027152 0.020640140399336815 0.09320074319839478
+0.06657180190086365 0.001115253777243197 0.0007709900964982808 0.020114393904805183 0.0682329386472702
+0.08830910176038742 0.0015351350884884596 0.0014237777795642614 0.025436362251639366 0.14267084002494812
+0.09769038110971451 0.0003360835835337639 0.0012102460023015738 0.04108241945505142 0.1660139113664627
+0.10146887600421906 0.0008661335450597107 0.0015230762073770165 0.07373728603124619 0.12418247014284134
+0.08471657335758209 0.0019452492706477642 0.0034233005717396736 0.0296705961227417 0.06480914354324341
+0.08120495080947876 0.00027101265732198954 0.001709415577352047 0.023366054520010948 0.08565250784158707
+0.09827432036399841 0.0023322890046983957 0.0010766972554847598 0.03088538534939289 0.1318681538105011
+0.08048931509256363 0.004021004773676395 0.004227465018630028 0.038878750056028366 0.08608578890562057
+0.07963777333498001 0.0003693644539453089 0.0017809710698202252 0.03508666530251503 0.10329233109951019
+0.08365599066019058 0.0024262508377432823 0.0014406964182853699 0.03428369760513306 0.10992896556854248
+0.07395487278699875 0.0039283535443246365 0.002003571717068553 0.022838044911623 0.10193949937820435
+0.07615166902542114 0.0005555997486226261 0.0009504936169832945 0.04702916741371155 0.0846182256937027
+0.08322899043560028 0.0010303838644176722 0.0021554178092628717 0.027295108884572983 0.11571106314659119
+0.09029635787010193 0.001918220310471952 0.0022215598728507757 0.03254442662000656 0.08337049186229706
+0.06951218843460083 0.00030474551022052765 0.0004695349489338696 0.0195660088211298 0.06989475339651108
+0.09937013685703278 0.0005872689071111381 0.0010885498486459255 0.021578297019004822 0.14324615895748138
+0.097871795296669 0.0008876764914020896 0.0013004826614633203 0.03449716046452522 0.16546429693698883
+0.09697312861680984 0.00028432192630134523 0.0012551405234262347 0.049645934253931046 0.11441779881715775
+0.08498509228229523 0.00014247838407754898 0.0010178356897085905 0.0295618437230587 0.11910174787044525
+0.08565632998943329 0.0004497290647123009 0.0009338475647382438 0.02033621072769165 0.09107791632413864
+0.08185891807079315 0.001305661047808826 0.00044540848466567695 0.039388611912727356 0.07595048099756241
+0.08962322026491165 0.0010702993022277951 0.0008041682885959744 0.030289167538285255 0.10933572798967361
+0.09571553766727448 0.0001729990472085774 0.0010186334839090705 0.0398225300014019 0.15226872265338898
+0.08688671886920929 0.0010976811172440648 0.0009044695761986077 0.03616204485297203 0.11132784932851791
+0.07178963720798492 0.0017730266554281116 0.001388810109347105 0.03024832345545292 0.08449161797761917
+0.09073224663734436 0.00039599076262675226 0.0005517858662642539 0.021213892847299576 0.07919567823410034
+0.08157377690076828 0.00091721978969872 0.0007649268372915685 0.022206785157322884 0.07547283172607422
+0.08791320770978928 0.0022113597951829433 0.002387802815064788 0.020449291914701462 0.11541617661714554
+0.09250573813915253 0.0006179003394208848 0.001481328741647303 0.035948071628808975 0.12480379641056061
+0.09477987885475159 0.0008479223470203578 0.0011258372105658054 0.04187067970633507 0.132734477519989
+0.09882144629955292 0.0025046085938811302 0.0014465332496911287 0.03458947315812111 0.11469772458076477
+0.09624676406383514 0.0011304639047011733 0.0010917835170403123 0.0419590063393116 0.1259947568178177
+0.08534638583660126 0.000424673140514642 0.0005531690549105406 0.023455332964658737 0.11145048588514328
+0.08848170936107635 0.0023001863155514 0.0016812692629173398 0.027150046080350876 0.1081857904791832
+0.09089066833257675 0.0011624214239418507 0.001145853428170085 0.03062223084270954 0.13699299097061157
+0.08088642358779907 0.00023473345208913088 0.0007845068466849625 0.06609875708818436 0.0942116379737854
+0.09226851165294647 0.0015877934638410807 0.002174847759306431 0.02515079453587532 0.13485927879810333
+0.06659325957298279 0.0006495400448329747 0.0013714825036004186 0.02611226588487625 0.059985704720020294
+0.07722654193639755 0.00039328052662312984 0.0005356912151910365 0.021810555830597878 0.08443819731473923
+0.06726415455341339 0.0016857897862792015 0.0012718156212940812 0.024113379418849945 0.07333848625421524
+0.06520958244800568 0.0008434036863036454 0.0008133229566738009 0.02330493927001953 0.06570233404636383
+0.09273139387369156 0.0002765502140391618 0.000590155425015837 0.022564122453331947 0.07050298899412155
+0.0957346111536026 0.0011249163653701544 0.001133918878622353 0.030704280361533165 0.1325773447751999
+0.08489193022251129 0.0006855737883597612 0.0006977930315770209 0.04199204593896866 0.12268216162919998
+0.08904892206192017 0.00031652124016545713 0.0008200663723982871 0.03403804078698158 0.13703961670398712
+0.09330172091722488 0.0005010532913729548 0.0009743287810124457 0.042498327791690826 0.11940488964319229
+0.08308013528585434 0.00041442576912231743 0.0005379730719141662 0.02757691591978073 0.09234397113323212
+0.09864579141139984 0.00024287700944114476 0.0006214457098394632 0.026455413550138474 0.09452100843191147
+0.06565920263528824 0.00032061594538390636 0.0005378662026487291 0.028545746579766273 0.05168892815709114
+0.08375464379787445 0.000269950891379267 0.0011350599816069007 0.03341124951839447 0.10569164901971817
+0.10088126361370087 0.00026853583403863013 0.0013850658433511853 0.03513374552130699 0.1746908277273178
+0.10323479771614075 0.00037335208617150784 0.0011915141949430108 0.044348035007715225 0.19276200234889984
+0.10020169615745544 0.00033238393370993435 0.000785678974352777 0.038329217582941055 0.09693852812051773
+0.08177541196346283 0.00023571908241137862 0.001290942425839603 0.0233880877494812 0.03900952637195587
+0.09144221246242523 0.0002707716776058078 0.0013163002440705895 0.017632031813263893 0.06744148582220078
+0.0909009799361229 0.001843046280555427 0.0007374568376690149 0.017888521775603294 0.1362202912569046
+0.09886151552200317 0.002308395691215992 0.002573105040937662 0.04020066186785698 0.16587795317173004
+0.10248767584562302 0.00016395087004639208 0.0013926845276728272 0.06440278887748718 0.17825104296207428
+0.08843138813972473 0.0039442977868020535 0.0028063757345080376 0.06596078723669052 0.06585226953029633
+0.0929209291934967 0.004070398397743702 0.0057146488688886166 0.023553337901830673 0.053237657994031906
+0.09472675621509552 0.00036172752152197063 0.0007712461519986391 0.03303602710366249 0.11412409693002701
+0.07649293541908264 0.008198829367756844 0.00563200656324625 0.023343190550804138 0.09152086824178696
+0.08592642843723297 0.005891679786145687 0.007869071327149868 0.06142698973417282 0.11540523171424866
+0.09812991321086884 0.0008951300987973809 0.0015871304785832763 0.050109297037124634 0.1457042545080185
+0.08020588010549545 0.011326690204441547 0.005933595821261406 0.0473555251955986 0.10356106609106064
+0.08089081197977066 0.007223838474601507 0.004933580756187439 0.03825578838586807 0.06277719140052795
+0.07298385351896286 0.0004995340714231133 0.0009560679318383336 0.05471893772482872 0.06262117624282837
+0.07016710937023163 0.009249535389244556 0.006384207401424646 0.0609101839363575 0.05826805531978607
+0.08039496093988419 0.0059851668775081635 0.003604266792535782 0.023677635937929153 0.10068632662296295
+0.0752590000629425 0.00027785563725046813 0.0009778118692338467 0.023750849068164825 0.1188340038061142
+0.08286193013191223 0.005086181219667196 0.004260302055627108 0.030968237668275833 0.11721863597631454
+0.06590099632740021 0.003608680795878172 0.002945838263258338 0.03297335281968117 0.07491683959960938
+0.062317900359630585 0.0001959215587703511 0.0004938475321978331 0.035235874354839325 0.07621430605649948
+0.09475149214267731 0.003201352898031473 0.0014291256666183472 0.047290951013565063 0.1172841265797615
+0.09987464547157288 0.0035762768238782883 0.0018618333851918578 0.019264979287981987 0.1415015012025833
+0.0616784542798996 0.0004478182236198336 0.0008179789292626083 0.02647252194583416 0.06426361948251724
+0.08002650737762451 0.001544808503240347 0.0008944584988057613 0.033424146473407745 0.10498642176389694
+0.06412612646818161 0.003567503299564123 0.0017670157831162214 0.039201509207487106 0.06869861483573914
+0.08470756560564041 0.0010651772608980536 0.0007625356665812433 0.03193693608045578 0.12250912934541702
+0.08713553100824356 0.0007777936989441514 0.001286858576349914 0.020502548664808273 0.10108979046344757
+0.07991902530193329 0.0026324675418436527 0.003047327743843198 0.03396157547831535 0.0673300176858902
+0.08518992364406586 0.000802196969743818 0.0011701375478878617 0.024973856285214424 0.10931418091058731
+0.06318177282810211 0.0006507497164420784 0.001646522548981011 0.02566608227789402 0.0839959904551506
+0.07250285148620605 0.0016411663964390755 0.0023102769628167152 0.03050241433084011 0.09516070783138275
+0.07348397374153137 0.0008401894010603428 0.001932385261170566 0.03664177656173706 0.09822501242160797
+0.09423324465751648 0.000306376488879323 0.0014655324630439281 0.033734824508428574 0.12002865970134735
+0.08470780402421951 0.0011168473865836859 0.0008940365514717996 0.026698879897594452 0.046897754073143005
+0.0904044508934021 0.0010798753937706351 0.00047579754027538 0.01890687085688114 0.08972138911485672
+0.08914045989513397 0.0005508318427018821 0.0005345432437025011 0.0213185902684927 0.10336115211248398
+0.10049138963222504 0.00033813543268479407 0.0010060521308332682 0.03321042284369469 0.172100231051445
+0.07429485023021698 0.0005518082762137055 0.000984941958449781 0.045055754482746124 0.11231130361557007
+0.07748736441135406 0.001294890302233398 0.0010857165325433016 0.03079420141875744 0.10699556022882462
+0.07994276285171509 0.0004626727313734591 0.0011738900793716311 0.019351115450263023 0.06289739161729813
+0.08050845563411713 0.0005945429438725114 0.0005996975232847035 0.020849402993917465 0.04161582887172699
+0.09480655193328857 0.002650856738910079 0.0022214774508029222 0.02776668407022953 0.06941536068916321
+0.06980732083320618 0.0012288190191611648 0.0024098719004541636 0.028018632903695107 0.10024186968803406
+0.09193546324968338 0.0005071369232609868 0.0009111501276493073 0.038556281477212906 0.16507796943187714
+0.07697837054729462 0.0038721601013094187 0.001406978815793991 0.042612653225660324 0.12002051621675491
+0.0701943188905716 0.003383574541658163 0.0019128724234178662 0.03178710862994194 0.06432691961526871
+0.07549631595611572 0.00017775665037333965 0.0006460209842771292 0.026160860434174538 0.06074831634759903
+0.07723499834537506 0.0029561370611190796 0.0013383673503994942 0.02225509285926819 0.07217926532030106
+0.08060853183269501 0.00452444888651371 0.002538074040785432 0.02563285455107689 0.11041277647018433
+0.08902159333229065 0.0007349727093242109 0.0016266416059806943 0.031883783638477325 0.13096651434898376
+0.07423048466444016 0.001708874711766839 0.0023542444687336683 0.04767363891005516 0.05907193198800087
+0.10610312223434448 0.00466383108869195 0.005242896731942892 0.0330708809196949 0.16986294090747833
+0.07289393246173859 0.0007431767880916595 0.0018130785319954157 0.026387330144643784 0.10470439493656158
+0.08998461067676544 0.0021230848506093025 0.0012870966456830502 0.033851414918899536 0.09084266424179077
+0.0867692083120346 0.004474977497011423 0.002513090381398797 0.031090054661035538 0.0843459814786911
+0.07941967248916626 0.000802106165792793 0.0005718134925700724 0.01675129123032093 0.08092666417360306
+0.08997458219528198 0.0006216886104084551 0.0007076514302752912 0.0179691631346941 0.13097506761550903
+0.0871695876121521 0.0016974264290183783 0.000739828625228256 0.039128124713897705 0.144388347864151
+0.08750715106725693 0.001298088813200593 0.0007788795628584921 0.053828947246074677 0.13463681936264038
+0.079637810587883 0.0005301146884448826 0.0004045646928716451 0.027598097920417786 0.08741401135921478
+0.08345286548137665 0.0005437318468466401 0.00045291241258382797 0.018128396943211555 0.09037619829177856
+0.0937592089176178 0.0019008757080882788 0.0008099278784357011 0.019112085923552513 0.08953673392534256
+0.07701876759529114 0.0026306824292987585 0.0023557490203529596 0.028866291046142578 0.08910262584686279
+0.07849642634391785 0.0002544952149037272 0.0013466788223013282 0.041084952652454376 0.10239240527153015
+0.08255164325237274 0.002251786645501852 0.0013436046428978443 0.034728340804576874 0.11898227781057358
+0.0881965309381485 0.004148678854107857 0.002471217419952154 0.030344178900122643 0.13807156682014465
+0.07869705557823181 0.0007626909064128995 0.0017110236221924424 0.030046148225665092 0.09685958176851273
+0.0853215903043747 0.001199211343191564 0.0019692417699843645 0.030684880912303925 0.09334032982587814
+0.06377355754375458 0.003098150249570608 0.0023956538643687963 0.02675064653158188 0.042977094650268555
+0.07284398376941681 0.0010365177877247334 0.0009613037109375 0.01697644591331482 0.06431449949741364
+0.050958480685949326 0.0003065172058995813 0.001714976504445076 0.018694614991545677 0.040008120238780975
+0.07117505371570587 0.0008265612414106727 0.002001759596168995 0.03943263366818428 0.09800487011671066
+0.09772977232933044 0.0009047136991284788 0.0023673984687775373 0.061780814081430435 0.15009234845638275
+0.09635069966316223 0.000615062890574336 0.0012704202672466636 0.044735122472047806 0.15479522943496704
+0.06920812278985977 0.00019188865553587675 0.0012349124299362302 0.02536550536751747 0.0760720893740654
+0.0840572789311409 0.0006991755799390376 0.000724084151443094 0.02167356386780739 0.06275661289691925
+0.0781717449426651 0.0019426228245720267 0.000414613721659407 0.020069614052772522 0.07547438889741898
+0.08561450242996216 0.0015573478303849697 0.0007128426805138588 0.027453897520899773 0.1312633901834488
+0.06472543627023697 0.00017548908363096416 0.0009038374992087483 0.06510201841592789 0.08400280773639679
+0.08774624764919281 0.0015881237341091037 0.0009627199615351856 0.03728866204619408 0.1386537104845047
+0.07333886623382568 0.0026017846539616585 0.0017087266314774752 0.03581717610359192 0.07774297893047333
+0.06743404269218445 0.00035209572524763644 0.0004904283559881151 0.02284999191761017 0.08719903230667114
+0.07516385614871979 0.0016285491874441504 0.002185810124501586 0.02211940474808216 0.06331370770931244
+0.08635343611240387 0.002522510476410389 0.003563333535566926 0.046696972101926804 0.055847570300102234
+0.07536055147647858 0.00037838343996554613 0.0010061616776511073 0.03904275596141815 0.055199168622493744
+0.08707860857248306 0.0014175751712173223 0.0007919730851426721 0.025509273633360863 0.1395028978586197
+0.08519186079502106 0.0022483360953629017 0.0017531048506498337 0.0428440161049366 0.13625574111938477
+0.08947315812110901 0.0012647743569687009 0.0017818877240642905 0.057526227086782455 0.10039176046848297
+0.09273232519626617 0.0003073415136896074 0.0008451142930425704 0.030975010246038437 0.11716252565383911
+0.08088785409927368 0.0007397494046017528 0.0003928776422981173 0.01693098619580269 0.07773956656455994
+0.08018289506435394 0.0022945906966924667 0.0006539773312397301 0.01580023393034935 0.07360576093196869
+0.0702611431479454 0.001923153642565012 0.0010579823283478618 0.022807102650403976 0.07149770855903625
+0.09075330197811127 0.0001874165318440646 0.0009030396468006074 0.042795922607183456 0.12054725736379623
+0.08897083252668381 0.001954245613887906 0.0016031664563342929 0.041029203683137894 0.1488921344280243
+0.07680387794971466 0.002749606966972351 0.002582784043624997 0.03761005774140358 0.10499442368745804
+0.06540121138095856 0.00037536106538027525 0.001063056057319045 0.04057372733950615 0.05290111526846886
+0.07570285350084305 0.0009906121995300055 0.00047671920037828386 0.017987163737416267 0.09631893783807755
+0.09132099896669388 0.003385064424946904 0.0009533897391520441 0.020878244191408157 0.1011817455291748
+0.07957464456558228 0.0019166964339092374 0.0009654651512391865 0.027138726785779 0.062613844871521
+0.07276368141174316 0.00017963830032385886 0.0006567526725120842 0.03240038827061653 0.0937771275639534
+0.08066874742507935 0.0021033084485679865 0.000972610549069941 0.03603982925415039 0.11819565296173096
+0.093548484146595 0.0021312402095645666 0.0008357780170626938 0.03825998678803444 0.15709634125232697
+0.07212729752063751 0.0003399843699298799 0.000669739325530827 0.033994439989328384 0.06478694826364517
+0.059951648116111755 0.0004058411286678165 0.0012280556838959455 0.03249485790729523 0.06851515173912048
+0.0691784918308258 0.0008067243034020066 0.0010294816456735134 0.015790922567248344 0.06499991565942764
+0.07434019446372986 0.00047937041381374 0.0007232497446238995 0.01978607475757599 0.07507017999887466
+0.0805593729019165 0.00012384337605908513 0.0006799025577493012 0.022038675844669342 0.09470748901367188
+0.08010617643594742 0.0004589752934407443 0.0007676758104935288 0.040541134774684906 0.0911082923412323
+0.07421322166919708 0.0008716814918443561 0.0009137197630479932 0.037709567695856094 0.08765146136283875
+0.06461580097675323 0.0005225499626249075 0.0007184636779129505 0.029099389910697937 0.07799845188856125
+0.06857416033744812 0.00023724672792013735 0.00042525684693828225 0.021594388410449028 0.06764191389083862
+0.09112915396690369 0.0011313623981550336 0.000519853550940752 0.02461346611380577 0.09449820965528488
+0.05492892488837242 0.0011404907563701272 0.0005336744943633676 0.029957827180624008 0.03803877905011177
+0.06928203999996185 0.00016738819249439985 0.0005282819038257003 0.026768332347273827 0.07155922055244446
+0.08321031183004379 0.0007018796168267727 0.0015688315033912659 0.038040656596422195 0.09852128475904465
+0.05885813385248184 0.0004950187867507339 0.001018276670947671 0.02310759760439396 0.07254420965909958
+0.05876173824071884 0.00028302244027145207 0.000749057624489069 0.04842156171798706 0.06052647903561592
+0.06781124323606491 0.0005110690835863352 0.0009762686095200479 0.027971280738711357 0.07556038349866867
+0.06133154034614563 0.00030247445101849735 0.000553710968233645 0.02262517437338829 0.07015034556388855
+0.06829115003347397 0.0002516377717256546 0.0008047541487030685 0.027259917929768562 0.05688861757516861
+0.0800846666097641 0.00027853500796481967 0.0007129589794203639 0.02508798986673355 0.11125216633081436
+0.07244323194026947 0.0001820313191274181 0.000697072537150234 0.037374842911958694 0.09699417650699615
+0.08776132762432098 0.00019850405806209892 0.0006128363311290741 0.029300520196557045 0.13093049824237823
+0.07552787661552429 0.00017740031762514263 0.0005348199629224837 0.02405720204114914 0.08064553886651993
+0.07172608375549316 0.00014480979007203132 0.00048108684131875634 0.02393505908548832 0.05193847790360451
+0.06520147621631622 0.00025091777206398547 0.0003963246417697519 0.019647177308797836 0.08073695003986359
+0.07018924504518509 0.0005915170768275857 0.0012004603631794453 0.026359980925917625 0.08961368352174759
+0.07340584695339203 0.0002714346337597817 0.0013188160955905914 0.03940944746136665 0.1013597920536995
+0.07255421578884125 0.0006522870389744639 0.0006590154371224344 0.0337287038564682 0.08350390940904617
+0.08518320322036743 0.0017481433460488915 0.0009825039887800813 0.033400073647499084 0.1211760863661766
+0.08231896907091141 0.0006005022441968322 0.0006492238026112318 0.019126037135720253 0.1164102554321289
+0.07814142107963562 0.0004127232823520899 0.0007738464628346264 0.025799760594964027 0.06690698862075806
+0.0715247392654419 0.0011664588237181306 0.0012097782455384731 0.02363814227283001 0.08465517312288284
+0.06083357706665993 0.00039485745946876705 0.00047672775690443814 0.02936282381415367 0.04949771612882614
+0.07514579594135284 0.0002545581664890051 0.0013007562374696136 0.02115040086209774 0.10747519135475159
+0.06381871551275253 0.00035366826341487467 0.0012104753404855728 0.026938553899526596 0.06348007917404175
+0.09582683444023132 0.00014689916861243546 0.0007363478071056306 0.03673533722758293 0.09072931855916977
+0.06730881333351135 0.00014240470773074776 0.0006307379808276892 0.029869189485907555 0.0398896299302578
+0.08851529657840729 0.00016500195488333702 0.0005587234045378864 0.017447486519813538 0.0957936942577362
+0.060401879251003265 0.00041332573164254427 0.0005622868775390089 0.01680092141032219 0.05983305722475052
+0.07709090411663055 0.0005037939408794045 0.0013676718808710575 0.032918449491262436 0.13341154158115387
+0.0764189064502716 0.00015746770077385008 0.0010438822209835052 0.04917691648006439 0.10553023219108582
+0.06869770586490631 0.0018469158094376326 0.0013309763744473457 0.029833173379302025 0.07719460874795914
+0.0764939934015274 0.0015739547088742256 0.002696078270673752 0.02539546601474285 0.07497072219848633
+0.07527400553226471 0.0004550545127131045 0.0005810305010527372 0.0375884510576725 0.05526735633611679
+0.06475809961557388 0.004562846384942532 0.0039014241192489862 0.020956626161932945 0.03037010133266449
+0.06233100965619087 0.0023883504327386618 0.003962031565606594 0.03476612642407417 0.08651132136583328
+0.06648121774196625 0.0007178991800174117 0.0004626695881597698 0.03424350917339325 0.09915225207805634
+0.07388654351234436 0.005837632808834314 0.0025690181646496058 0.029166996479034424 0.10544726997613907
+0.06872131675481796 0.004677369259297848 0.0019635437056422234 0.050875596702098846 0.05783437564969063
+0.06130209192633629 0.00018197018653154373 0.0006711188470944762 0.03157597780227661 0.04482555389404297
+0.06961771100759506 0.003467345144599676 0.0019435243448242545 0.04836706072092056 0.0681636780500412
+0.056111253798007965 0.003531873691827059 0.002162796212360263 0.023406244814395905 0.05525963008403778
+0.09128304570913315 0.0004459956253413111 0.000805647112429142 0.02744479849934578 0.15668785572052002
+0.07680293917655945 0.0012515587732195854 0.0022129749413579702 0.04573849216103554 0.09248347580432892
+0.07770884037017822 0.003137789200991392 0.003715269733220339 0.03896642103791237 0.05789496749639511
+0.06338945031166077 0.0008759694756008685 0.0026866798289120197 0.015478159300982952 0.05114084854722023
+0.07327096164226532 0.001481174840591848 0.00038753097760491073 0.015525136142969131 0.06823037564754486
+0.07238928973674774 0.005977926775813103 0.0036507025361061096 0.03930458799004555 0.07249903678894043
+0.06650041043758392 0.002364299027249217 0.0034427016507834196 0.02988217957317829 0.09421860426664352
+0.09012208878993988 0.0012306551216170192 0.0006954834097996354 0.03463917598128319 0.1719406247138977
+0.0762292891740799 0.007751849479973316 0.0037723074201494455 0.03223637118935585 0.10682093352079391
+0.06614364683628082 0.004184344783425331 0.0032916483469307423 0.022278809919953346 0.04791639372706413
+0.0854051411151886 0.000379773584427312 0.0002957041433546692 0.055047448724508286 0.039968740195035934
+0.08103509247303009 0.0071084885857999325 0.002940793987363577 0.012200084514915943 0.08314960449934006
+0.07350556552410126 0.006346793379634619 0.0036793393082916737 0.015504106879234314 0.09139330685138702
+0.0717272087931633 0.00016282354772556573 0.0008400678634643555 0.03097257763147354 0.1101241484284401
+0.06991294026374817 0.0047218697145581245 0.002933282870799303 0.04078874737024307 0.10018948465585709
+0.07584837824106216 0.007324064616113901 0.004862651228904724 0.03304954245686531 0.09182369709014893
+0.08085018396377563 0.0008110612398013473 0.001439285813830793 0.025575285777449608 0.08177366107702255
+0.08611947298049927 0.00294159771874547 0.0027780120726674795 0.02590910904109478 0.08770330995321274
+0.08190363645553589 0.006479667034000158 0.005460903514176607 0.022673197090625763 0.09558383375406265
+0.06141276657581329 0.0013346136547625065 0.002261369489133358 0.030343296006321907 0.08220979571342468
+0.07326748967170715 0.0015213785227388144 0.0016085471725091338 0.033063195645809174 0.07254324853420258
+0.08615566045045853 0.005360539071261883 0.00404408760368824 0.0336555652320385 0.12143759429454803
+0.08085883408784866 0.0021111860405653715 0.002203426556661725 0.038074932992458344 0.09940771758556366
+0.06893574446439743 0.000586592243053019 0.0006773262284696102 0.02081151120364666 0.08922266215085983
+0.0704469233751297 0.00461886590346694 0.0022134282626211643 0.017476746812462807 0.07358130812644958
+0.08496715873479843 0.0031806526239961386 0.0017862181412056088 0.019360940903425217 0.11678078025579453
+0.07579972594976425 0.00014103100693318993 0.00041807832894846797 0.021787112578749657 0.11858896911144257
+0.07598085701465607 0.0026940114330500364 0.001527942018583417 0.029695138335227966 0.0829457938671112
+0.07394671440124512 0.0035277516581118107 0.0023048596922308207 0.028025874868035316 0.05843917280435562
+0.08426445722579956 0.00040748342871665955 0.0005850106826983392 0.021065032109618187 0.10750634968280792
+0.06979164481163025 0.0017947371816262603 0.0018515317933633924 0.028809107840061188 0.10313840210437775
+0.07261651754379272 0.003256681840866804 0.0030346824787557125 0.02686973847448826 0.09803947061300278
+0.07557802647352219 0.0003033309185411781 0.0008051292388699949 0.025585222989320755 0.08148019015789032
+0.07990618050098419 0.0017745117656886578 0.0024508098140358925 0.027408502995967865 0.08089858293533325
+0.06339522451162338 0.0025059024337679148 0.003540723118931055 0.029238494113087654 0.06527939438819885
+0.07192370295524597 0.0001991310273297131 0.0009665326215326786 0.024324411526322365 0.10535082221031189
+0.06814174354076385 0.0022944791708141565 0.0033913569059222937 0.04143444821238518 0.10032276064157486
+0.07080301642417908 0.0017770095728337765 0.0023758059833198786 0.03416819125413895 0.0657043606042862
+0.06214689463376999 0.00021292197925504297 0.0010517328046262264 0.04753267765045166 0.06457746773958206
+0.07527071237564087 0.0015595507575199008 0.002269021002575755 0.01910296641290188 0.11596231907606125
+0.05356747657060623 0.0003761465195566416 0.0005981519934721291 0.025892669335007668 0.05309079959988594
+0.06446477770805359 0.00035984048736281693 0.0010815265122801065 0.04117312282323837 0.05370241776108742
+0.05627226456999779 0.0004758042632602155 0.0006410685018636286 0.022076088935136795 0.06096963956952095
+0.05784042924642563 0.00011317033931845799 0.0004774722328875214 0.019450388848781586 0.05572505295276642
+0.10225997865200043 9.603788930689916e-05 0.0005330980056896806 0.015864236280322075 0.14478862285614014
+0.07032816112041473 9.005775791592896e-05 0.0007541773957200348 0.03271510452032089 0.09226153045892715
+0.09080354869365692 0.00019641024118755013 0.0006228430429473519 0.036555707454681396 0.06476323306560516
+0.06202751398086548 0.000672413210850209 0.0007120619993656874 0.02871870994567871 0.06131567433476448
+0.07493001967668533 0.00030541245359927416 0.000826914154458791 0.01564352586865425 0.11067942529916763
+0.08017829060554504 0.000534455815795809 0.0008330236887559295 0.020331325009465218 0.12335038930177689
+0.07932519912719727 0.0016958311898633838 0.003238135715946555 0.047406211495399475 0.10315266251564026
+0.07737445831298828 0.00026710794190876186 0.001548480475321412 0.04512998089194298 0.08531112223863602
+0.08829262852668762 0.0018858402036130428 0.0029245158657431602 0.03080149181187153 0.11303000152111053
+0.06483607739210129 0.0014123653527349234 0.0032451082952320576 0.05354563519358635 0.0755699872970581
+0.06707267463207245 0.00048505759332329035 0.0005829626461490989 0.023159010335803032 0.06615001708269119
+0.08023758977651596 0.0034109149128198624 0.0036015219520777464 0.03511135280132294 0.10779640823602676
+0.0613996647298336 0.0011666709324344993 0.0019468454411253333 0.03056272678077221 0.05426822975277901
+0.06643659621477127 0.001017782837152481 0.0009298754157498479 0.01630130037665367 0.08785055577754974
+0.07898905128240585 0.003310607746243477 0.0028750854544341564 0.01683175191283226 0.11945678293704987
+0.06725424528121948 0.0005750926211476326 0.0005752905854023993 0.028565706685185432 0.0688043013215065
+0.07379446923732758 0.0007975884946063161 0.0012391459895297885 0.03234146162867546 0.09094969183206558
+0.06253355741500854 0.0012833056971430779 0.0009734867489896715 0.026125207543373108 0.0860249400138855
+0.052522413432598114 0.00023389438865706325 0.00030521603184752166 0.018565528094768524 0.05189577117562294
+0.07131839543581009 0.00018314330372959375 0.00094546441687271 0.03934945911169052 0.06180727109313011
+0.05914369598031044 0.00028559108613990247 0.0005702811176888645 0.03057173639535904 0.05123688653111458
+0.08064422011375427 0.0001508641871623695 0.0007319322321563959 0.03263477608561516 0.1210564598441124
+0.05844848230481148 0.00011917907249880955 0.0005467709270305932 0.032658644020557404 0.08660371601581573
+0.05866033211350441 9.282110841013491e-05 0.0005110367783345282 0.028494365513324738 0.04569099470973015
+0.07202275097370148 0.00012904965842608362 0.000575540994759649 0.021523360162973404 0.05846943333745003
+0.06289741396903992 0.0001620859547983855 0.0004272530204616487 0.018269240856170654 0.05366111919283867
+0.05556320399045944 0.00018216928583569825 0.000848111929371953 0.02605428732931614 0.07275866717100143
+0.06782205402851105 9.804962610360235e-05 0.0008911332115530968 0.04670204594731331 0.09883002936840057
+0.06224077567458153 0.0004330860683694482 0.0005320179625414312 0.04601068049669266 0.06759730726480484
+0.07273626327514648 0.0007864658255130053 0.0010028793476521969 0.021204089745879173 0.08987735211849213
+0.06305524706840515 0.00014808791456744075 0.0005067615420557559 0.03521687164902687 0.0680728554725647
+0.0781867653131485 0.0007096939953044057 0.0006076924619264901 0.023220248520374298 0.06678367406129837
+0.06967609375715256 0.0009951874380931258 0.0006612157449126244 0.019710006192326546 0.08077873289585114
+0.06700702011585236 0.00020060106180608273 0.0004520202928688377 0.024141473695635796 0.07544395327568054
+0.05849155783653259 0.0003004776663146913 0.000714797992259264 0.021719353273510933 0.06554722040891647
+0.07868719846010208 0.00045729006524197757 0.0007820846512913704 0.02270386554300785 0.126607745885849
+0.07993613183498383 0.00013182294787839055 0.00046588972327299416 0.025676842778921127 0.09548354893922806
+0.05909587815403938 0.0002161264419555664 0.0004055944154970348 0.022804686799645424 0.0574677437543869
+0.07609318941831589 0.0003894468827638775 0.0003698984219226986 0.01802803948521614 0.07495476305484772
+0.07226534932851791 0.00029186811298131943 0.0004894774756394327 0.017007825896143913 0.07805798947811127
+0.05126838758587837 0.00014575725072063506 0.0006327806040644646 0.019322490319609642 0.07307733595371246
+0.0853569433093071 0.0002444967976771295 0.0010967556154355407 0.035040583461523056 0.11174697428941727
+0.06301475316286087 0.0009192335419356823 0.0005843854160048068 0.03602571412920952 0.0841449722647667
+0.06250220537185669 0.0011479074601083994 0.0007825400680303574 0.020762545987963676 0.07817201316356659
+0.06542626768350601 0.00017051414761226624 0.0005941062700003386 0.018865441903471947 0.07282184064388275
+0.05867896229028702 0.0009189968113787472 0.0010856669396162033 0.01894257590174675 0.04545868560671806
+0.05293945595622063 0.001459232298657298 0.0014528295723721385 0.024513382464647293 0.05365672707557678
+0.07337198406457901 0.0001675850071478635 0.00046502819168381393 0.030121369287371635 0.08295966684818268
+0.08057625591754913 0.000934918993152678 0.0014711612602695823 0.022457430139183998 0.10648059099912643
+0.06430202722549438 0.0011171007063239813 0.0017732848646119237 0.022590918466448784 0.08753988891839981
+0.04708877578377724 0.0003017164708580822 0.0005406956770457327 0.025824908167123795 0.053499940782785416
+0.07060186564922333 0.001170396339148283 0.0009018463315442204 0.026668306440114975 0.10876823961734772
+0.06755216419696808 0.0010241266572847962 0.0006565317162312567 0.06222177669405937 0.047457557171583176
+0.07478020340204239 0.0001631604100111872 0.0005034884088672698 0.020380539819598198 0.0642659068107605
+0.07429206371307373 0.0004960524966008961 0.0009392327629029751 0.019071824848651886 0.0817306861281395
+0.06340711563825607 0.00020471253083087504 0.0003408755292184651 0.030709898099303246 0.07252901047468185
+0.08365403115749359 0.00017745506193023175 0.0009157840977422893 0.03465499356389046 0.13066978752613068
+0.07398761063814163 0.00013823367771692574 0.0004610098258126527 0.022779470309615135 0.09394589811563492
+0.060884345322847366 0.0002664531930349767 0.0003138966567348689 0.01514364778995514 0.05013054609298706
+0.04841461777687073 0.00014514970825985074 0.00029905338305979967 0.013491898775100708 0.03332364931702614
+0.09415994584560394 0.00032703502802178264 0.0008632905664853752 0.019995246082544327 0.0786634162068367
+0.06378714740276337 0.00018433919467497617 0.0009991948027163744 0.03225317597389221 0.057570941746234894
+0.07759188115596771 0.0005336917238309979 0.0005324899684637785 0.023282339796423912 0.12298794090747833
+0.0835462212562561 0.0008785875397734344 0.000690273882355541 0.024809909984469414 0.14634853601455688
+0.06330141425132751 0.0004576239443849772 0.0007889795815572143 0.03136099874973297 0.09578844904899597
+0.06200335547327995 0.00021061314328107983 0.0007264101877808571 0.02797694504261017 0.05949702858924866
+0.08980277180671692 0.0004603664274327457 0.0006124228821136057 0.032588157802820206 0.05983702093362808
+0.06916211545467377 0.00120399403385818 0.0002902815176639706 0.014627828262746334 0.042417433112859726
+0.05194878205657005 0.00151066726539284 0.0006862033624202013 0.01729024574160576 0.07238457351922989
+0.07632062584161758 0.0004526035045273602 0.0015537296421825886 0.03807660564780235 0.12244163453578949
+0.06672586500644684 0.0011060910765081644 0.0010051247663795948 0.04524435102939606 0.09272032976150513
+0.062216226011514664 0.004091036971658468 0.0019783771131187677 0.024615151807665825 0.06733212620019913
+0.07800859212875366 0.0015742727555334568 0.0015102997422218323 0.02904331311583519 0.0777156874537468
+0.07907512784004211 0.0007838534656912088 0.000661111727822572 0.02047327533364296 0.07384537905454636
+0.06407696008682251 0.004187325481325388 0.0031896685250103474 0.025413284078240395 0.0635743960738182
+0.06303684413433075 0.0015482526505365968 0.0015318092191591859 0.02835347130894661 0.08625605702400208
+0.064754419028759 0.0007111074519343674 0.0010056983446702361 0.06719571352005005 0.0888395830988884
+0.044135719537734985 0.003638548543676734 0.0036909254267811775 0.03063119202852249 0.03604297339916229
+0.06228381022810936 0.0013218035455793142 0.0022751849610358477 0.024869011715054512 0.06434577703475952
+0.06668619066476822 0.001012412249110639 0.0004298185231164098 0.017313484102487564 0.07004871219396591
+0.088056780397892 0.004527125507593155 0.0022333126980811357 0.023806491866707802 0.04590357467532158
+0.06065244972705841 0.001894654706120491 0.0011713854037225246 0.014386477880179882 0.04865788668394089
+0.07756125926971436 0.0002857467334251851 0.0005121864378452301 0.019506249576807022 0.12552432715892792
+0.06449438631534576 0.0022560954093933105 0.0008146279724314809 0.030805982649326324 0.08062301576137543
+0.05199041962623596 0.0020256966818124056 0.0007716635009273887 0.043292127549648285 0.057954274117946625
+0.06184752658009529 0.0005541479331441224 0.0003844423627015203 0.0385131798684597 0.058790020644664764
+0.08286212384700775 0.00042651445255614817 0.0003259487566538155 0.013283521868288517 0.06435847282409668
+0.06316012144088745 0.0021027936600148678 0.0010934481397271156 0.012976945377886295 0.03736172616481781
+0.0656822994351387 0.0015221296343952417 0.001619096496142447 0.02276221662759781 0.08365298807621002
+0.06574652343988419 0.00015361035184469074 0.0009340266115032136 0.03867980092763901 0.10465142130851746
+0.07770310342311859 0.0025636237114667892 0.0017818680498749018 0.036522313952445984 0.1299779862165451
+0.07079482078552246 0.0026027201674878597 0.0031379142310470343 0.02701544016599655 0.09590090066194534
+0.06656689941883087 0.00012540635361801833 0.0008419573423452675 0.025618648156523705 0.0504562184214592
+0.062487319111824036 0.002853990765288472 0.0008619922446087003 0.019524484872817993 0.0601106658577919
+0.06683458387851715 0.00416614068672061 0.0019128138665109873 0.018810110166668892 0.04942350462079048
+0.05099944397807121 0.0005995873361825943 0.0008021926623769104 0.02125541865825653 0.03891533613204956
+0.07108931243419647 0.0013800731394439936 0.0009560397593304515 0.03002448007464409 0.10552303493022919
+0.06218358874320984 0.003939870744943619 0.0014828962739557028 0.04492238909006119 0.08271478116512299
+0.0767505019903183 0.00235104002058506 0.001396461739204824 0.03168788179755211 0.08414057642221451
+0.07717282325029373 8.77270576893352e-05 0.0004439930198714137 0.0581577904522419 0.08458827435970306
+0.0662534087896347 0.0027506595943123102 0.0007346026250161231 0.013774855062365532 0.062156498432159424
+0.06476163119077682 0.004400331526994705 0.0018454379169270396 0.01906895823776722 0.08195202052593231
+0.05601781606674194 0.0008376544574275613 0.0009146653464995325 0.02359645627439022 0.06912200897932053
+0.06517469137907028 0.0013831504620611668 0.0010887369280681014 0.036934442818164825 0.07331031560897827
+0.07434754073619843 0.004634047858417034 0.003811946138739586 0.027640346437692642 0.10339101403951645
+0.0718834176659584 0.000994843547232449 0.0012110504321753979 0.025723345577716827 0.09468647837638855
+0.056899648159742355 0.0012612127466127276 0.0010306687327101827 0.021679643541574478 0.05377352237701416
+0.06488529592752457 0.003631404833868146 0.0026119232643395662 0.02176193706691265 0.04291360080242157
+0.06335372477769852 0.0010187821462750435 0.0009181411005556583 0.019057296216487885 0.06647775322198868
+0.07837170362472534 0.00046415196266025305 0.00051168876234442 0.02344399131834507 0.09696300327777863
+0.05379093438386917 0.0019719814881682396 0.0009147298405878246 0.03519382327795029 0.047143254429101944
+0.04972977191209793 0.0017009174916893244 0.0010648607276380062 0.03641046583652496 0.05764629319310188
+0.06426414847373962 0.00034137244801968336 0.0006967547815293074 0.02370942197740078 0.07538367062807083
+0.058149877935647964 0.0007322597084566951 0.0005121935391798615 0.021780652925372124 0.06699243932962418
+0.08053821325302124 0.0019463853677734733 0.0010055204620584846 0.020878296345472336 0.07879526913166046
+0.0561162605881691 0.00093026552349329 0.0006433132220990956 0.038065213710069656 0.04545315355062485
+0.06249864399433136 0.0002383526152698323 0.000673917296808213 0.022151175886392593 0.062495578080415726
+0.052274975925683975 0.0010742039885371923 0.0014069457538425922 0.023941004648804665 0.05906789377331734
+0.06950578093528748 0.000511687365360558 0.0008772097062319517 0.024687860161066055 0.10167581588029861
+0.07542562484741211 0.00023678744037169963 0.000563857494853437 0.025831997394561768 0.10569466650485992
+0.07499057799577713 0.0005750166019424796 0.0008835938642732799 0.022102631628513336 0.09454900026321411
+0.060541585087776184 0.00019609762239269912 0.000353956944309175 0.015696141868829727 0.03666253760457039
+0.04777801036834717 0.0001706384791759774 0.0004794099077116698 0.015364175662398338 0.028275758028030396
+0.049273353070020676 0.00014837282651569694 0.0003871456137858331 0.01650291495025158 0.059380270540714264
+0.0561874583363533 0.00016893699648790061 0.0007203122368082404 0.02472621761262417 0.05132053419947624
+0.05632404983043671 0.0002190309314755723 0.0009635350434109569 0.025409899652004242 0.05456266924738884
+0.054631806910037994 0.0002332588192075491 0.0005867155850864947 0.021587004885077477 0.06686624884605408
+0.047138139605522156 0.0005502073909156024 0.000579273619223386 0.01835607923567295 0.04454086720943451
+0.05573225021362305 0.00025166774867102504 0.0005545218009501696 0.02008424513041973 0.05853138118982315
+0.06417892128229141 0.0002169207000406459 0.0008468091837130487 0.023838916793465614 0.05073189362883568
+0.050002798438072205 0.0003770120965782553 0.0007287064217962325 0.027921821922063828 0.04839979112148285
+0.07621514797210693 0.00011804685345850885 0.000647722918074578 0.025417139753699303 0.13501764833927155
+0.06962016969919205 0.00016786163905635476 0.0008782242657616735 0.02632044069468975 0.09893535077571869
+0.07619384676218033 0.0002099257690133527 0.0010285840835422277 0.026090769097208977 0.09962811321020126
+0.056005217134952545 8.366691326955333e-05 0.0010738017736002803 0.02391459420323372 0.03454665094614029
+0.06209680438041687 0.00040828128112480044 0.0005494813667610288 0.02176089771091938 0.055945221334695816
+0.07095743715763092 0.0010032432619482279 0.0006288131698966026 0.01656695082783699 0.08736468106508255
+0.0623500682413578 0.0007055990863591433 0.0013661660486832261 0.022265488281846046 0.08107927441596985
+0.07650666683912277 0.00024983700131997466 0.001161872292868793 0.03564540296792984 0.08232379704713821
+0.06643014401197433 0.002309598261490464 0.000663574377540499 0.029961412772536278 0.10826556384563446
+0.05493722856044769 0.00230402871966362 0.0009028365602716804 0.023249566555023193 0.03947931528091431
+0.06052166968584061 0.0002009210002142936 0.0004037268226966262 0.018864763900637627 0.059398334473371506
+0.06529080122709274 0.0014199971919879317 0.0013120839139446616 0.018057677894830704 0.07291535288095474
+0.06576815992593765 0.002053049858659506 0.0017827927367761731 0.03007984720170498 0.08096665889024734
+0.06267464905977249 0.00024202409258577973 0.0006475654663518071 0.04600357264280319 0.08600121736526489
+0.07754895091056824 0.0008340072236023843 0.001833809190429747 0.03360399976372719 0.09293056279420853
+0.06664222478866577 0.0009944351622834802 0.0020210049115121365 0.04423696547746658 0.07417718321084976
+0.055861685425043106 0.00010533586464589462 0.0007522321538999677 0.026424948126077652 0.03742437809705734
+0.07180938124656677 0.0009796156082302332 0.0007717547123320401 0.08580216765403748 0.061039138585329056
+0.05486920103430748 0.0013105273246765137 0.0013620404060930014 0.01890953630208969 0.053887538611888885
+0.06372833251953125 0.00022029275714885443 0.0006020835717208683 0.0223702359944582 0.07381144911050797
+0.048460863530635834 0.0008319297339767218 0.0003219398786313832 0.025413796305656433 0.06341235339641571
+0.06652743369340897 0.001670793746598065 0.0006875174003653228 0.028800731524825096 0.06848283857107162
+0.06150797754526138 0.000859157822560519 0.000668802298605442 0.020191548392176628 0.07395412027835846
+0.05237699672579765 0.00019881052139680833 0.00034343861625529826 0.016363734379410744 0.03466847538948059
+0.07573171705007553 0.0015223040245473385 0.0006737381918355823 0.015976205468177795 0.0927039161324501
+0.053402386605739594 0.0014334123115986586 0.0008612585952505469 0.023337475955486298 0.05546131730079651
+0.05466648191213608 0.00011049157183151692 0.00042843754636123776 0.03282923623919487 0.05817633867263794
+0.0564645417034626 0.0014007380232214928 0.0025360090658068657 0.029823865741491318 0.0752163678407669
+0.06957195699214935 0.0008187816711142659 0.002031629206612706 0.021460020914673805 0.09505726397037506
+0.05290365591645241 0.000490586506202817 0.0003248740977142006 0.028571145609021187 0.06811759620904922
+0.07745178788900375 0.00193891953676939 0.0015380315016955137 0.04956221207976341 0.07202329486608505
+0.06719034910202026 0.0004905066452920437 0.00038546323776245117 0.018137065693736076 0.026321902871131897
+0.06493318825960159 0.00027702542138285935 0.0005445837159641087 0.017458265647292137 0.07426287233829498
+0.09575142711400986 0.00040942360647022724 0.0007222550339065492 0.028632743284106255 0.14982348680496216
+0.06431354582309723 0.00029884561081416905 0.0023329395335167646 0.07454191893339157 0.058753371238708496
+0.061644598841667175 0.0008348129340447485 0.0006477686110883951 0.035011857748031616 0.06365907192230225
+0.05643802881240845 0.000815515115391463 0.0008371655712835491 0.02221071533858776 0.045182887464761734
+0.0699804276227951 0.00023812845756765455 0.0003397404798306525 0.017675265669822693 0.08316117525100708
+0.04677243530750275 0.002162084449082613 0.0020080155227333307 0.03358571603894234 0.03438505157828331
+0.05720144882798195 0.0014120059786364436 0.0017668631626293063 0.03595374524593353 0.06960506737232208
+0.07516859471797943 0.0004573546757455915 0.0006744199781678617 0.022997410967946053 0.06774383783340454
+0.05425390601158142 0.0021606122609227896 0.0026632642839103937 0.01607654057443142 0.06184874475002289
+0.04857965558767319 0.0004534422478172928 0.0006921601016074419 0.01623990200459957 0.05351032316684723
+0.052090615034103394 0.0009542630286887288 0.0008383888052776456 0.018669873476028442 0.04059470444917679
+0.056715287268161774 0.0016963104717433453 0.0014867590507492423 0.027489323168992996 0.05658717080950737
+0.056296397000551224 0.0001856095768744126 0.00034499799949117005 0.01778329536318779 0.07859459519386292
+0.062302689999341965 0.0008106976747512817 0.001016163150779903 0.016034666448831558 0.0746365487575531
+0.06563082337379456 0.0008656579884700477 0.0007617473020218313 0.025026781484484673 0.09663259238004684
+0.062101446092128754 0.00015450891805812716 0.00044871994759887457 0.03262624889612198 0.05783764272928238
+0.0663885697722435 0.00017289166862610728 0.00033593628904782236 0.023310724645853043 0.04155990481376648
+0.06600967049598694 0.00044813257409259677 0.0004455690796021372 0.01723601296544075 0.05240330100059509
+0.06511413305997849 0.0006685158587060869 0.00033745975815691054 0.0181740615516901 0.1037089079618454
+0.06453417986631393 0.0005749688134528697 0.0010528372367843986 0.035910606384277344 0.08096246421337128
+0.06425316631793976 0.00014360868954099715 0.000817980500869453 0.04082563892006874 0.07557754963636398
+0.07150349020957947 0.001150893047451973 0.0012987948721274734 0.027120888233184814 0.07476762682199478
+0.06216663122177124 0.000644038780592382 0.0008357883780263364 0.02316930517554283 0.09831791371107101
+0.05897628515958786 0.000353322975570336 0.0009854878298938274 0.022641662508249283 0.05924512818455696
+0.06842561066150665 0.0012854273663833737 0.0028519153129309416 0.027654090896248817 0.060248762369155884
+0.0685662105679512 0.00023326317023020238 0.0012014107778668404 0.024754682555794716 0.05875789374113083
+0.053620316088199615 0.000961964949965477 0.0006064707413315773 0.02483524940907955 0.06907179206609726
+0.06300771236419678 0.0011802552035078406 0.00067421660060063 0.017351048067212105 0.07616814225912094
+0.05326821655035019 0.00012773550406564027 0.0005852991016581655 0.027331862598657608 0.05569034814834595
+0.07048477977514267 0.0003537716984283179 0.0008892596233636141 0.031405624002218246 0.09893456846475601
+0.055710263550281525 0.0001774157426552847 0.00033876224188134074 0.022774428129196167 0.06800395250320435
+0.07698942720890045 8.12393773230724e-05 0.0012747137807309628 0.0170493945479393 0.05147889629006386
+0.07294388115406036 0.00011500986147439107 0.0011320984922349453 0.016484679654240608 0.03712606430053711
+0.07009570300579071 0.0006341154221445322 0.0002908338501583785 0.013617229647934437 0.06317488104104996
+0.05670331418514252 0.0010082569206133485 0.0014613332459703088 0.028629856184124947 0.08204443007707596
+0.08652260154485703 9.815413795877248e-05 0.0013066590763628483 0.044784065335989 0.15550754964351654
+0.0709645226597786 0.001979711465537548 0.001098324777558446 0.029860908165574074 0.11241272836923599
+0.07259486615657806 0.0024233232252299786 0.0019168425351381302 0.03806258365511894 0.08610546588897705
+0.0696098804473877 0.00025453948182985187 0.0005893263150937855 0.018764346837997437 0.03399818018078804
+0.07082410156726837 0.001956257037818432 0.0012077174615114927 0.014664976857602596 0.03635039180517197
+0.0640958920121193 0.0029181570280343294 0.0020091405604034662 0.01421711128205061 0.04235195368528366
+0.06286241114139557 0.0005593406385742128 0.0011133856605738401 0.05967497453093529 0.05612090229988098
+0.050654638558626175 0.001154226716607809 0.000804243260063231 0.02944922260940075 0.0743708461523056
+0.09130749106407166 0.003491287352517247 0.0016028000973165035 0.03831477835774422 0.17007918655872345
+0.06699656695127487 0.0018289609579369426 0.00137609604280442 0.07564836740493774 0.09530804306268692
+0.07446403056383133 0.0001176674704765901 0.0005783483502455056 0.019427696242928505 0.10562624782323837
+0.0693107470870018 0.0022878714371472597 0.0003670157748274505 0.01946207880973816 0.023546719923615456
+0.10853219777345657 0.004185612313449383 0.0005895454087294638 0.05347207188606262 0.050236642360687256
+0.05701399967074394 0.0014169674832373857 0.0005363511736504734 0.01530513633042574 0.034477926790714264
+0.068815216422081 0.00017598345584701747 0.0005082158022560179 0.025372721254825592 0.12028709799051285
+0.06527392566204071 0.0022275042720139027 0.0013076638570055366 0.04980937018990517 0.10675421357154846
+0.09263665974140167 0.004010180477052927 0.0014719261089339852 0.04264356195926666 0.16629056632518768
+0.07128798216581345 0.0017317793099209666 0.0014091074699535966 0.03872436657547951 0.1155911460518837
+0.06100153177976608 0.00027676683384925127 0.0005534100346267223 0.01647324673831463 0.03618885204195976
+0.07190420478582382 0.0034476860892027617 0.0011360507924109697 0.018161574378609657 0.033346131443977356
+0.06891505420207977 0.004041756968945265 0.0009492447134107351 0.016621388494968414 0.04497142508625984
+0.04616250842809677 0.0006304926937445998 0.0006128437235020101 0.018465576693415642 0.06253483146429062
+0.06468406319618225 0.0008468325831927359 0.0014259167946875095 0.030542604625225067 0.08368384838104248
+0.0654384046792984 0.0025689532049000263 0.002067022491246462 0.03764853999018669 0.08948300033807755
+0.05073472857475281 0.0016316709807142615 0.0012982572661712766 0.025592871010303497 0.06340674310922623
+0.05745946615934372 0.00018778617959469557 0.0005712371785193682 0.018045345321297646 0.07435096800327301
+0.07048313319683075 0.0018735001794993877 0.0017129243351519108 0.014139814302325249 0.07991816848516464
+0.07088855654001236 0.0018412485951557755 0.0019101096550002694 0.05348382890224457 0.041845668107271194
+0.05834713205695152 0.0002467080776114017 0.0004432243586052209 0.01859174482524395 0.057556916028261185
+0.06408638507127762 0.0012614354491233826 0.0006640852079726756 0.018636832013726234 0.08639249205589294
+0.07060492038726807 0.0012658386258408427 0.0009452792583033442 0.031996626406908035 0.11373971402645111
+0.05529904365539551 0.00039807724533602595 0.0008316531893797219 0.037243057042360306 0.07848741859197617
+0.06508676707744598 0.0002012015029322356 0.0005834242911078036 0.030492862686514854 0.09698835015296936
+0.050424277782440186 0.00011998883564956486 0.0006024328176863492 0.020854009315371513 0.047321099787950516
+0.050847310572862625 0.0003411388606764376 0.0009977375157177448 0.01894170604646206 0.046992748975753784
+0.06046349182724953 0.000990899046882987 0.0004648214380722493 0.02092691697180271 0.02779306285083294
+0.05170624330639839 0.0009351706248708069 0.0003808251640293747 0.020166706293821335 0.022623280063271523
+0.05357615649700165 0.00026031830930151045 0.000500711495988071 0.02120477519929409 0.06891000270843506
+0.07903631031513214 0.00029916997300460935 0.000741808267775923 0.029329929500818253 0.1445455402135849
+0.07248228788375854 0.0008944409200921655 0.0007387047517113388 0.03730103000998497 0.1267836093902588
+0.07403530180454254 0.0009468220523558557 0.0008579806890338659 0.026871517300605774 0.09398176521062851
+0.06538072228431702 0.00015229421842377633 0.0006544576026499271 0.051873933523893356 0.05795283243060112
+0.060374170541763306 0.0009673212771303952 0.0005545661551877856 0.01536914985626936 0.06829139590263367
+0.07045834511518478 0.003069888101890683 0.002513946034014225 0.015409255400300026 0.050695158541202545
+0.07285843044519424 0.0012732672039419413 0.0026858816854655743 0.0164125207811594 0.07648453861474991
+0.05260821059346199 0.0006206340622156858 0.0005628634826280177 0.022231750190258026 0.04032508656382561
+0.054010286927223206 0.0043898229487240314 0.0007403700728900731 0.023215750232338905 0.05145411565899849
+0.05339399725198746 0.003059827256947756 0.0006963422638364136 0.021816235035657883 0.07135636359453201
+0.059300582855939865 0.0001827713567763567 0.0018324877601116896 0.028058011084794998 0.09296914935112
+0.06234519183635712 0.0016627191798761487 0.0038677253760397434 0.04395497962832451 0.08349298685789108
+0.05324609577655792 0.000510466517880559 0.0007378358277492225 0.05597924441099167 0.030652722343802452
+0.06859389692544937 0.0004653106152545661 0.003133353777229786 0.030943967401981354 0.04882987216114998
+0.06555893272161484 0.00035220064455643296 0.0035940948873758316 0.0213779304176569 0.04290783405303955
+0.05909637361764908 0.0007311771623790264 0.0007376219145953655 0.02410171367228031 0.03733169659972191
+0.06630781292915344 0.0022189724259078503 0.0010921051725745201 0.03169732540845871 0.09804072231054306
+0.07052730768918991 0.000504339812323451 0.0007613056222908199 0.027053775265812874 0.12038791924715042
+0.06331197172403336 0.0004983452381566167 0.0013876438606530428 0.01847805082798004 0.0983482226729393
+0.05559566244482994 0.0009373400826007128 0.0013490943238139153 0.019296279177069664 0.05833493545651436
+0.07384927570819855 0.00015450331557076424 0.0003224758547730744 0.017187925055623055 0.09002922475337982
+0.060814570635557175 0.0005253865383565426 0.0012874214444309473 0.044402312487363815 0.046430375427007675
+0.04826895892620087 0.000525265873875469 0.0016138432547450066 0.015419451519846916 0.046498823910951614
+0.05793209373950958 0.00013099476927891374 0.0009743800037540495 0.01210680790245533 0.06455890834331512
+0.04793316870927811 0.000600685307290405 0.000735717301722616 0.0196062121540308 0.051695071160793304
+0.059038519859313965 0.0012326689902693033 0.0008000460802577436 0.026563825085759163 0.07905128598213196
+0.054983340203762054 0.0010084557579830289 0.0008572663064114749 0.026777157559990883 0.055864714086055756
+0.04511547088623047 0.0002231746620964259 0.0011610091896727681 0.014787168242037296 0.05082045495510101
+0.057573217898607254 0.0009188750409521163 0.0008615753031335771 0.01802409067749977 0.06648760288953781
+0.059290580451488495 0.0022199612576514482 0.001315723406150937 0.023515846580266953 0.07406013458967209
+0.05433392524719238 0.0010054765734821558 0.0013323916355147958 0.02401754818856716 0.03684306889772415
+0.05781801789999008 0.00026238057762384415 0.002132254419848323 0.02183874137699604 0.0709272176027298
+0.06704966723918915 0.0009532295516692102 0.002188656246289611 0.0238414965569973 0.07894285768270493
+0.05442073568701744 0.0003449611831456423 0.0008690078975632787 0.026227818801999092 0.06796533614397049
+0.06606855243444443 0.0003059418231714517 0.0008140618447214365 0.025265121832489967 0.07738429307937622
+0.06903233379125595 0.0005406824639067054 0.0007984788971953094 0.0222454946488142 0.07421045005321503
+0.06633789092302322 0.00012685306137427688 0.0005356287583708763 0.014265637844800949 0.070000059902668
+0.06034650281071663 0.00028169737197458744 0.00048562948359176517 0.018369227647781372 0.05475173145532608
+0.06392639130353928 0.0002723251236602664 0.0003272054891567677 0.02645433507859707 0.08135629445314407
+0.07348065823316574 0.00012169845285825431 0.0005023632547818124 0.03509437292814255 0.1283058077096939
+0.053308866918087006 0.00013869148097001016 0.00048609028453938663 0.021445421501994133 0.067030169069767
+0.05447795242071152 0.0001363447809126228 0.00030279552447609603 0.02274920418858528 0.059495147317647934
+0.04975919798016548 0.00010110998846357688 0.00034438323928043246 0.015210213139653206 0.03896615654230118
+0.06499574333429337 0.00020803431107196957 0.00027375403442420065 0.04337853193283081 0.05344439670443535
+0.05644291639328003 0.00032175969681702554 0.0007860696641728282 0.020560741424560547 0.049479998648166656
+0.04374835640192032 0.0002777496993076056 0.0009920608717948198 0.03199616074562073 0.04025864973664284
+0.054988645017147064 0.0006957374862395227 0.0004915648023597896 0.026553543284535408 0.08241342753171921
+0.05550863966345787 0.001159549574367702 0.0012326292926445603 0.02920558676123619 0.06214845925569534
+0.049108877778053284 0.0003077606961596757 0.0010266713798046112 0.023472214117646217 0.06979008764028549
+0.05396091192960739 0.0005610674852505326 0.0006507121142931283 0.023153970018029213 0.06419744342565536
+0.052238985896110535 0.0014500853139907122 0.00033692060969769955 0.024265749379992485 0.04680100083351135
+0.05500746890902519 0.0006756891962140799 0.0004351423413027078 0.019010720774531364 0.03934599831700325
+0.07834406197071075 0.00016557591152377427 0.0006630765274167061 0.01893787458539009 0.09819459170103073
+0.07009150832891464 0.00011591618385864422 0.00043678245856426656 0.01845276728272438 0.10102157294750214
+0.05942915752530098 0.0001997736981138587 0.0014077682280912995 0.04375085234642029 0.07646494358778
+0.053547102957963943 0.0009884374449029565 0.00039350317092612386 0.03367733582854271 0.0744827389717102
+0.056817784905433655 0.0015816184459254146 0.0007933261804282665 0.019080905243754387 0.04950003698468208
+0.05827419459819794 0.00014052756887394935 0.00041163197602145374 0.012678888626396656 0.04977134242653847
+0.04628137871623039 0.0015363183338195086 0.0020843814127147198 0.015796774998307228 0.047268372029066086
+0.06372900307178497 0.001415895763784647 0.0024282624945044518 0.020662248134613037 0.06328918039798737
+0.05188420042395592 0.0003421676519792527 0.0004051378055009991 0.019547846168279648 0.062342867255210876
+0.06473296880722046 0.0018222469370812178 0.001621031784452498 0.03560686111450195 0.06342668831348419
+0.05509522929787636 0.000691713357809931 0.0008176207193173468 0.02572816237807274 0.03930166736245155
+0.04913613572716713 0.00039097401895560324 0.0006790247862227261 0.026324357837438583 0.038510724902153015
+0.05809870362281799 0.0010371273383498192 0.0017104700673371553 0.030555833131074905 0.07826845347881317
+0.05572653189301491 0.00023335836885962635 0.0005490804323926568 0.024027345702052116 0.07796423137187958
+0.06500907242298126 0.00045156150008551776 0.0014454235788434744 0.021236073225736618 0.0944916233420372
+0.05715349316596985 0.00029474575421772897 0.0013322359882295132 0.029965044930577278 0.06741300225257874
+0.0626508891582489 0.00015640270430594683 0.0006897688144817948 0.024697650223970413 0.06633715331554413
+0.06685944646596909 0.0006365907611325383 0.0003138917963951826 0.017390746623277664 0.0749889388680458
+0.05499193072319031 0.0009148953249678016 0.0003745104477275163 0.016149627044796944 0.04306266829371452
+0.0539880096912384 0.00042215653229504824 0.0005994639359414577 0.018296685069799423 0.07260968536138535
+0.06036413460969925 0.00016282068099826574 0.0005077955429442227 0.02783256582915783 0.0772852972149849
+0.05294276028871536 0.0010973676107823849 0.0004074011812917888 0.03195281699299812 0.06526850163936615
+0.05468933284282684 0.0012569279642775655 0.0003560946206562221 0.02090296894311905 0.07666313648223877
+0.05364140123128891 0.0002508714096620679 0.0004296165716368705 0.02636856585741043 0.03618350997567177
+0.06382185220718384 0.000489874160848558 0.0011319976765662432 0.020682087168097496 0.06182140111923218
+0.0607808381319046 0.0006494201370514929 0.0009554981952533126 0.017981206998229027 0.052846744656562805
+0.056495506316423416 0.00013324004248715937 0.0006149184191599488 0.02029590681195259 0.08046959340572357
+0.06880495697259903 0.000563439738471061 0.0006830269121564925 0.030231565237045288 0.10200108587741852
+0.06667524576187134 0.0004628548340406269 0.0008448108565062284 0.03334194794297218 0.09556197375059128
+0.0645609050989151 0.00026186200557276607 0.0007948768907226622 0.02855721302330494 0.08096303045749664
+0.058548618108034134 0.0003403724986128509 0.0004653069481719285 0.021778654307127 0.0659983828663826
+0.0584665946662426 0.0010891897836700082 0.00017862375534605235 0.022746039554476738 0.04933600872755051
+0.06554298847913742 0.001127568306401372 0.00034677283838391304 0.018452879041433334 0.04379332438111305
+0.06543445587158203 0.00032496018684469163 0.00045132701052352786 0.015431633219122887 0.05616007000207901
+0.04857596009969711 0.0002501441049389541 0.00046504990314133465 0.02163720689713955 0.05810508131980896
+0.07919714599847794 0.0013211200712248683 0.0005116378306411207 0.031296256929636 0.155729278922081
+0.06360568851232529 0.0019538605120033026 0.0005921876872889698 0.03813600540161133 0.09743920713663101
+0.053057193756103516 0.0011279430473223329 0.0005639940500259399 0.021726571023464203 0.046339236199855804
+0.06572422385215759 0.00011813282617367804 0.0004949159338138998 0.015414491295814514 0.04758356884121895
+0.06512421369552612 0.002527916105464101 0.0007217700476758182 0.016237623989582062 0.02274470031261444
+0.045381058007478714 0.00367238512262702 0.002206098986789584 0.01667908765375614 0.03630693256855011
+0.053641099482774734 0.0003684388066176325 0.0008896152721717954 0.026580102741718292 0.07826311886310577
+0.061951737850904465 0.0021562734618782997 0.0013666197191923857 0.03986746817827225 0.09461086243391037
+0.055187419056892395 0.0046814801171422005 0.003221775172278285 0.03870949521660805 0.07748406380414963
+0.04771287366747856 0.0009904137114062905 0.0017527006566524506 0.027325887233018875 0.05802444741129875
+0.044109687209129333 0.0013205744326114655 0.0004853963910136372 0.019594430923461914 0.03725118562579155
+0.06524865329265594 0.005100814625620842 0.0016919622430577874 0.01644059643149376 0.05043509230017662
+0.062463536858558655 0.002362137893214822 0.0009004001040011644 0.02031015045940876 0.025713622570037842
+0.05968639254570007 0.0002403042744845152 0.0006517131696455181 0.01862434297800064 0.05519593507051468
+0.059174664318561554 0.0022031196858733892 0.001176525722257793 0.0276037584990263 0.055640850216150284
+0.061193905770778656 0.0019507389515638351 0.0009451949736103415 0.03408242017030716 0.09757629036903381
+0.05875147134065628 0.00043473971891216934 0.0007556708296760917 0.028516290709376335 0.09630180150270462
+0.053124357014894485 0.0004400770994834602 0.0010603771079331636 0.04140942916274071 0.06637305021286011
+0.047187864780426025 0.0004052925214637071 0.000408992578741163 0.016411148011684418 0.022115323692560196
+0.05128481611609459 0.00017943244893103838 0.000425025966251269 0.018154267221689224 0.028805585578083992
+0.055558785796165466 0.00014120095875114202 0.0005053848144598305 0.013430491089820862 0.0378837063908577
+0.05780749022960663 0.00018672012083698064 0.00038237927947193384 0.019163867458701134 0.07852592319250107
+0.049025461077690125 0.0002749598352238536 0.0004635157238226384 0.023222822695970535 0.05672124773263931
+0.04912029951810837 0.0001202049243147485 0.0003379393310751766 0.020738141611218452 0.05446456000208855
+0.061882589012384415 0.0004055943281855434 0.0004810952232219279 0.0170432198792696 0.0777420923113823
+0.06777085363864899 0.00035908393329009414 0.0003479829174466431 0.01531265303492546 0.08903823047876358
+0.046900756657123566 0.00016884831711649895 0.0006153483409434557 0.022147396579384804 0.05801407992839813
+0.0624854750931263 0.00026901683304458857 0.0007803775370121002 0.03249451518058777 0.06212805584073067
+0.0636247843503952 0.00019923350191675127 0.0005460606189444661 0.022020231932401657 0.06567694991827011
+0.05227566510438919 0.0006326494039967656 0.001316616078838706 0.015402359887957573 0.05419214442372322
+0.053553514182567596 0.00012500834418460727 0.0005182111053727567 0.015277705155313015 0.03534160181879997
+0.05961379408836365 0.001037290901876986 0.0007349844672717154 0.038377925753593445 0.06052818149328232
+0.05581934005022049 0.0008666327339597046 0.0011531462660059333 0.026538830250501633 0.06648257374763489
+0.05835328251123428 0.00021710744476877153 0.00038981769466772676 0.025194276124238968 0.05110958218574524
+0.052414700388908386 0.001421059831045568 0.0012759827077388763 0.01702054589986801 0.067099429666996
+0.05667931213974953 0.00047658028779551387 0.00045478984247893095 0.018349694088101387 0.05366881936788559
+0.04561128094792366 0.00029090375755913556 0.0012207734398543835 0.027189061045646667 0.047236695885658264
+0.05451105907559395 0.0004694747331086546 0.0009614109294489026 0.025543618947267532 0.05768311023712158
+0.05356581509113312 7.923251541797072e-05 0.00026825847453437746 0.018474940210580826 0.06338636577129364
+0.050463199615478516 0.00035067260614596307 0.0010812383843585849 0.029015500098466873 0.04265771806240082
+0.05550234019756317 7.398191519314423e-05 0.00046334677608683705 0.016421273350715637 0.03219720721244812
+0.04577097296714783 0.0005363616510294378 0.00046555206063203514 0.016627192497253418 0.053024858236312866
+0.047176748514175415 0.0007351316744461656 0.000958526914473623 0.01985231600701809 0.06432484835386276
+0.03752126544713974 0.00010659076360752806 0.0004238580004312098 0.02594306692481041 0.024722205474972725
+0.04360739141702652 0.0010159692028537393 0.0008474615751765668 0.018596256151795387 0.02108025923371315
+0.04533683508634567 0.0006310288445092738 0.0006622692453674972 0.014889376237988472 0.04445083066821098
+0.05582346394658089 0.00021699402714148164 0.000872419448569417 0.025944944471120834 0.06333205848932266
+0.06477096676826477 0.0006799958064220846 0.0014347609831020236 0.03549671918153763 0.09698989987373352
+0.052146732807159424 7.885677769081667e-05 0.0004180474206805229 0.02546095661818981 0.05570521950721741
+0.05765226483345032 0.0006481281016021967 0.0019900957122445107 0.016094645485281944 0.05284825339913368
+0.054999180138111115 0.00010264239244861528 0.000909043475985527 0.022591646760702133 0.04972899332642555
+0.052629418671131134 0.0010245753219351172 0.0007566988933831453 0.04329249635338783 0.03829066827893257
+0.0531393326818943 0.0011488369200378656 0.0011305343359708786 0.023257551714777946 0.05900883674621582
+0.05822243541479111 0.000190827515325509 0.0006176847964525223 0.017893768846988678 0.08189438283443451
+0.06622495502233505 0.000743701180908829 0.0006682004313915968 0.023727720603346825 0.1008748859167099
+0.05330228805541992 0.0005955850356258452 0.0008597951382398605 0.036021340638399124 0.06104963272809982
+0.04560914635658264 0.00019260667613707483 0.0005398937501013279 0.027638984844088554 0.05758832395076752
+0.0487111434340477 0.00010984743857989088 0.0003998558677267283 0.017892582342028618 0.0592169426381588
+0.06403288245201111 0.0002433313784422353 0.0004769030783791095 0.019411562010645866 0.044878751039505005
+0.06289280951023102 0.0004907691036351025 0.000512665428686887 0.024148255586624146 0.042565908282995224
+0.06746654212474823 0.00046041933819651604 0.0012760056415572762 0.02186683379113674 0.10522790253162384
+0.05135335028171539 0.00016346428310498595 0.0013900881167501211 0.04058947414159775 0.06310421228408813
+0.05801504850387573 0.0014573786174878478 0.000830718083307147 0.03901522606611252 0.07158856093883514
+0.05101395770907402 0.0021811234764754772 0.0018241858342662454 0.03770538792014122 0.059904105961322784
+0.05913882702589035 0.00011933473433600739 0.0005496551166288555 0.015295542776584625 0.08127358555793762
+0.05371829494833946 0.002316715894266963 0.0013456001179292798 0.023654745891690254 0.06093582883477211
+0.04838933050632477 0.0022725327871739864 0.0013634959468618035 0.05825145170092583 0.0432489849627018
+0.04495234042406082 0.00012245585094206035 0.000666515901684761 0.018315065652132034 0.04748592525720596
+0.06097416952252388 0.001711843884550035 0.003358179237693548 0.016794301569461823 0.05735312029719353
+0.06550009548664093 0.0004624309076461941 0.0012963006738573313 0.02388005331158638 0.06801207363605499
+0.053509317338466644 0.000861925829667598 0.0006878876592963934 0.03851037099957466 0.045843761414289474
+0.05337138473987579 0.0015809189062565565 0.0013711720239371061 0.02314780279994011 0.06656261533498764
+0.035440996289253235 0.00012623303337022662 0.00023619350395165384 0.01702282577753067 0.027761798352003098
+0.04558633267879486 0.0008489299798384309 0.0008532689535059035 0.01687425747513771 0.05531613901257515
+0.0533549040555954 0.0006980191101320088 0.0005235348944552243 0.023854302242398262 0.0529208704829216
+0.05499463528394699 0.0001320561277680099 0.0003916754212696105 0.026329083368182182 0.059591781347990036
+0.05245605856180191 0.00021330962772481143 0.0003516226133797318 0.019260255619883537 0.048680711537599564
+0.045837391167879105 0.0002942675491794944 0.00023857348423916847 0.013479403220117092 0.0335511639714241
+0.06457863748073578 0.0004275004321243614 0.0005779775674454868 0.04788538068532944 0.08857440203428268
+0.054895348846912384 0.00022855043062008917 0.001354046049527824 0.030483389273285866 0.05716076120734215
+0.05476678907871246 0.0004916961770504713 0.0006683848914690316 0.029653266072273254 0.0755259171128273
+0.05574966222047806 0.0017519781831651926 0.0013312792871147394 0.01736065186560154 0.04432591423392296
+0.061542950570583344 0.0003087589284405112 0.0005991791840642691 0.014289913699030876 0.059940751641988754
+0.04476580768823624 0.0013054613955318928 0.0016222659032791853 0.0182785727083683 0.05125916004180908
+0.04498979076743126 0.0019333618693053722 0.002798160072416067 0.027578111737966537 0.034793004393577576
+0.04798073321580887 0.00010194686910836026 0.00036874102079309523 0.027066359296441078 0.057587090879678726
+0.04662422835826874 0.002384644467383623 0.0030765163246542215 0.01812351495027542 0.025788160040974617
+0.053822822868824005 0.001226960914209485 0.0023933975026011467 0.031149640679359436 0.07652118057012558
+0.04519286006689072 0.000581201515160501 0.0009244066895917058 0.026274656876921654 0.05527445673942566
+0.058278799057006836 0.0023390199057757854 0.0015303051332011819 0.0322641022503376 0.0855553075671196
+0.05907249450683594 0.0005613507237285376 0.0005509324837476015 0.018297307193279266 0.08702748268842697
+0.052684806287288666 0.0003383454168215394 0.0015827689785510302 0.017744891345500946 0.047390166670084
+0.05462917685508728 0.00043470002128742635 0.001296431408263743 0.018543295562267303 0.016749009490013123
+0.05185725539922714 0.00019705183512996882 0.00028129294514656067 0.017587676644325256 0.044999945908784866
+0.04186074435710907 0.0008819385548122227 0.0022037664894014597 0.01808916963636875 0.030914362519979477
+0.052000392228364944 0.0002759082999546081 0.0022310689091682434 0.022943714633584023 0.06378643959760666
+0.0553489625453949 0.0008967327885329723 0.0009512082324363291 0.03229692950844765 0.08824532479047775
+0.04996805265545845 0.0031552871223539114 0.00127795513253659 0.03457034379243851 0.06662791222333908
+0.0538891963660717 0.0017433222383260727 0.001610813313163817 0.026435397565364838 0.04249345511198044
+0.048130542039871216 0.0003070624661631882 0.00041955316555686295 0.01453420240432024 0.029990728944540024
+0.056647419929504395 0.00457494892179966 0.0010957220802083611 0.012727439403533936 0.0345778688788414
+0.05534125119447708 0.0042609525844454765 0.00170894677285105 0.014277017675340176 0.04819699376821518
+0.053101420402526855 9.4662478659302e-05 0.0003207987465430051 0.020021196454763412 0.06933460384607315
+0.06415422260761261 0.003080695401877165 0.0018818216631188989 0.026123758405447006 0.10739950090646744
+0.05545283854007721 0.0034846412017941475 0.001832573558203876 0.02719421498477459 0.07836505025625229
+0.04861653968691826 0.00021576230938080698 0.0003135453735012561 0.02253374643623829 0.04661080613732338
+0.041987210512161255 0.0012456921394914389 0.0009640580974519253 0.01454098615795374 0.02765485644340515
+0.05427240952849388 0.0021076053380966187 0.0011166725307703018 0.034065328538417816 0.05289376899600029
+0.055985115468502045 0.0006453262758441269 0.0005680828471668065 0.010565333999693394 0.03553790599107742
+0.045264892280101776 0.0002449091989547014 0.00041920688818208873 0.01657882332801819 0.04499254375696182
+0.05470825731754303 0.0014796745963394642 0.0007401269976980984 0.030097337439656258 0.06377388536930084
+0.058080192655324936 0.0013337292475625873 0.0010075380560010672 0.031160365790128708 0.10111900418996811
+0.06649420410394669 0.00020662984752561897 0.0007247352623380721 0.02262645959854126 0.10694676637649536
+0.04557899385690689 0.0003842301666736603 0.0007834522402845323 0.05753099545836449 0.041432514786720276
+0.053460635244846344 0.0005808805581182241 0.0004527296987362206 0.017891131341457367 0.033618222922086716
+0.061688486486673355 0.00030996924033388495 0.000552066951058805 0.025112561881542206 0.05777273699641228
+0.06291499733924866 0.00024024094454944134 0.0007450013072229922 0.022946206852793694 0.0786178708076477
+0.047126397490501404 0.00010785657650558278 0.0006720086093991995 0.027692532166838646 0.03212084248661995
+0.05829266458749771 0.00022143313253764063 0.00043351517524570227 0.026545792818069458 0.05374656245112419
+0.04394683986902237 0.00027937712729908526 0.00042763131204992533 0.01717003807425499 0.04064876213669777
+0.06333474069833755 9.976215369533747e-05 0.0003372841456439346 0.019159484654664993 0.08718723058700562
+0.04480535537004471 0.00022539799101650715 0.000926441396586597 0.03180708363652229 0.05154148116707802
+0.04256414994597435 9.262277308152989e-05 0.0003160084306728095 0.022766200825572014 0.04776940122246742
+0.05094751715660095 0.0005290876142680645 0.0005360011709854007 0.01754186861217022 0.043979447335004807
+0.051654599606990814 0.00029034927138127387 0.0006339991232380271 0.012823031283915043 0.060550492256879807
+0.04556676745414734 0.0003276906500104815 0.0002986988110933453 0.026393407955765724 0.047980669885873795
+0.052278924733400345 0.0007836202858015895 0.0005974636878818274 0.03298575058579445 0.06320447474718094
+0.04201433062553406 9.756485087564215e-05 0.0002523506118450314 0.016329597681760788 0.038641467690467834
+0.0474354550242424 0.0004054255550727248 0.001281052129343152 0.013340533711016178 0.059195440262556076
+0.053358182311058044 0.00010065035166917369 0.0004190981271676719 0.013260435312986374 0.06919080018997192
+0.04666725546121597 0.0005197647260501981 0.0008096994133666158 0.01917479746043682 0.05381528288125992
+0.05406150966882706 0.0003518183948472142 0.0009701001108624041 0.01924511045217514 0.06368561089038849
+0.0420541912317276 0.00026462681125849485 0.00026324420468881726 0.014683065004646778 0.03378256782889366
+0.05200472101569176 0.0007941098301671445 0.0005813786410726607 0.01613440178334713 0.04941720515489578
+0.04212021827697754 0.00016945773677434772 0.0002675392315723002 0.01669382117688656 0.040745120495557785
+0.046869054436683655 0.0003030523657798767 0.0009510366944596171 0.020289044827222824 0.05285964533686638
+0.042898714542388916 0.0002871291944757104 0.0006019285065121949 0.016863133758306503 0.05566294491291046
+0.03871341422200203 7.699732668697834e-05 0.0003754050412680954 0.019428176805377007 0.03660988062620163
+0.034989092499017715 0.00021120806923136115 0.0007693859515711665 0.016929665580391884 0.030353344976902008
+0.05919722467660904 0.00011244157940382138 0.0004999861703254282 0.020405307412147522 0.037406716495752335
+0.057623617351055145 0.0003485018969513476 0.0007714275852777064 0.023799754679203033 0.0500149168074131
+0.0490197092294693 0.00024339034280274063 0.0008257678709924221 0.020568549633026123 0.0432116873562336
+0.04379378631711006 0.00010352252138545737 0.0003909906663466245 0.01758168637752533 0.056758660823106766
+0.039023250341415405 0.00028360990108922124 0.0005785723915323615 0.021415431052446365 0.03465447202324867
+0.04170839488506317 0.00024114458938129246 0.00037215714110061526 0.024428237229585648 0.046682123094797134
+0.03934982791543007 0.00010503056546440348 0.0004135614144615829 0.020522277802228928 0.02960163913667202
+0.040480829775333405 9.285692794946954e-05 0.000301328458590433 0.020056791603565216 0.03735320642590523
+0.04144129902124405 0.0001321511808782816 0.00022475497098639607 0.01565912738442421 0.028055129572749138
+0.047007933259010315 0.00016461103223264217 0.0002505906450096518 0.013490698300302029 0.04035741835832596
+0.05635268986225128 0.00016523947124369442 0.00020727794617414474 0.012670662254095078 0.048855431377887726
+0.04047655314207077 0.00011082526907557622 0.00024375924840569496 0.014219816774129868 0.02127237245440483
+0.0520760603249073 0.0001598290546098724 0.00027177410083822906 0.027260297909379005 0.07235202938318253
+0.05609571933746338 0.00021465604368131608 0.00038349590613506734 0.021236563101410866 0.09877508133649826
+0.041447073221206665 0.00023892291937954724 0.00035143745481036603 0.029500378295779228 0.05238678678870201
+0.04676690697669983 0.00020661711459979415 0.0003814997908193618 0.02128424309194088 0.05412597581744194
+0.04175110533833504 0.00014731785631738603 0.000636643439065665 0.01635272055864334 0.01976911909878254
+0.04167851060628891 0.0004892586730420589 0.00021043579909019172 0.01166620384901762 0.028341984376311302
+0.06417620927095413 0.0014815559843555093 0.0007378449663519859 0.019400017336010933 0.04521946609020233
+0.047137290239334106 0.0006041984888724983 0.0008713173447176814 0.01904204860329628 0.04010508581995964
+0.06471562385559082 0.00041560520185157657 0.0005329771665856242 0.02128213457763195 0.1195734292268753
+0.051027633249759674 0.001813029171898961 0.0007949940045364201 0.03387090563774109 0.07820158451795578
+0.05238921195268631 0.0015887205954641104 0.0009300185483880341 0.022068142890930176 0.06902048736810684
+0.05537061393260956 0.00014511257177218795 0.00038451081491075456 0.019590632990002632 0.044965848326683044
+0.05435706302523613 0.001198992133140564 0.0005266795051284134 0.010884384624660015 0.04903100058436394
+0.044008925557136536 0.0019623609259724617 0.0009412092040292919 0.012370575219392776 0.03594566881656647
+0.04018862545490265 0.0003364273579791188 0.0004462892538867891 0.013908816501498222 0.022538529708981514
+0.038299448788166046 0.0006350624025799334 0.0006876391125842929 0.021383486688137054 0.03851837292313576
+0.0480399951338768 0.0013310902286320925 0.0011131585342809558 0.0176249872893095 0.06461307406425476
+0.04134370759129524 0.00022142662783153355 0.0003035719273611903 0.02141285128891468 0.046945970505476
+0.04861389845609665 0.00028667825972661376 0.000787685508839786 0.019276488572359085 0.048648733645677567
+0.04606357961893082 0.00030690684798173606 0.00027372787008062005 0.017462214455008507 0.04569714888930321
+0.04707500338554382 8.987717592390254e-05 0.0003089272358920425 0.014875771477818489 0.02625797502696514
+0.04017440229654312 6.51225054753013e-05 0.0003296912764199078 0.013376654125750065 0.031497519463300705
+0.03898315876722336 0.00011604864994296804 0.00033594368142075837 0.021740930154919624 0.026729904115200043
+0.04849078878760338 6.823927105870098e-05 0.0004971424932591617 0.04964301735162735 0.04875534772872925
+0.046659041196107864 0.0002212639810750261 0.00027611551922746 0.020061245188117027 0.07252276688814163
+0.06035618111491203 0.0003064997144974768 0.000361895770765841 0.01933364011347294 0.09085816890001297
+0.0439603328704834 6.672029849141836e-05 0.00031501983175985515 0.024150390177965164 0.039404138922691345
+0.03277501463890076 0.00015809631440788507 0.0002935278753284365 0.021158600226044655 0.01899133250117302
+0.04834965616464615 0.00016565763507969677 0.00026072587934322655 0.012901222333312035 0.04519536718726158
+0.03598889708518982 0.0002171462110709399 0.00023221108131110668 0.013342438265681267 0.03752202168107033
+0.049319248646497726 0.00027940841391682625 0.0005458012456074357 0.02235005982220173 0.051323648542165756
+0.043496206402778625 7.337995339185e-05 0.0004633133939933032 0.026756806299090385 0.04626375809311867
+0.046894341707229614 0.000799918663688004 0.00041139780660159886 0.014802863821387291 0.05791936814785004
+0.04126473516225815 0.0005807565758004785 0.00026901805540546775 0.014039522968232632 0.03901882469654083
+0.0505593940615654 0.00012384681031107903 0.0009102756739594042 0.020345665514469147 0.05058837682008743
+0.051017969846725464 0.0004730221116915345 0.0015136412112042308 0.02625175751745701 0.0691801905632019
+0.04280808940529823 0.00010366593778599054 0.0003784492437262088 0.022756708785891533 0.0441531166434288
+0.04874887689948082 0.0006174075533635914 0.0013804757036268711 0.017339641228318214 0.0569668710231781
+0.048272427171468735 0.00021711802401114255 0.0011988766491413116 0.014530369080603123 0.05748135969042778
+0.044513553380966187 0.0005318375770002604 0.00030725300894118845 0.020510906353592873 0.048016343265771866
+0.04622335731983185 0.0011477468069642782 0.00047531924792565405 0.019157398492097855 0.03987334296107292
+0.050883084535598755 0.0001785036874935031 0.00024676090106368065 0.019095079973340034 0.048774778842926025
+0.05097556114196777 0.00019129693100694567 0.0004996770876459777 0.014013740234076977 0.051787473261356354
+0.0425741970539093 0.0001880563941085711 0.00047996058128774166 0.02392392046749592 0.024375513195991516
+0.0430607944726944 0.00010275846580043435 0.0005856320494785905 0.019318122416734695 0.044475145637989044
+0.053263451904058456 0.00011680758325383067 0.0008014913182705641 0.02004324086010456 0.07004445046186447
+0.056287869811058044 0.00022108718985691667 0.0003802456194534898 0.018937939777970314 0.06415413320064545
+0.05666494742035866 0.0003512321272864938 0.00041459291242063046 0.02086237445473671 0.0636894553899765
+0.04606962203979492 0.00014646595809608698 0.00043088558595627546 0.017953451722860336 0.05169016867876053
+0.04398631677031517 0.00013249846233520657 0.0003361844574101269 0.018993807956576347 0.042729269713163376
+0.04955713450908661 0.00035987311275675893 0.000295511563308537 0.015116997994482517 0.057457584887742996
+0.04733699560165405 0.00037999582127667964 0.0002784235985018313 0.012946819886565208 0.06412792205810547
+0.051254160702228546 0.00021501764422282577 0.00035692533128894866 0.020333776250481606 0.03659995645284653
+0.056067585945129395 8.939045801525936e-05 0.00032397761242464185 0.01413395069539547 0.06949685513973236
+0.04963834583759308 9.620202763471752e-05 0.00029464150429703295 0.01808062754571438 0.06966539472341537
+0.04202355444431305 0.00013094511814415455 0.00033776243799366057 0.01420146506279707 0.03736957907676697
+0.048635222017765045 0.00013930605200584978 0.0003719259984791279 0.019628629088401794 0.04310347139835358
+0.04793313145637512 0.0002804460527841002 0.00039438289240933955 0.013199425302445889 0.04152863845229149
+0.04280207306146622 0.0001737958227749914 0.000363736180588603 0.018802756443619728 0.053274087607860565
+0.05326418578624725 0.0001540974626550451 0.0003809620684478432 0.018367303535342216 0.07578624784946442
+0.038187261670827866 0.00010150115849683061 0.000562779139727354 0.02156711369752884 0.03886331617832184
+0.03109230101108551 9.516182035440579e-05 0.00045585009502246976 0.014027736149728298 0.021844878792762756
+0.046334873884916306 0.0001647588360356167 0.0004847033997066319 0.014953519217669964 0.03729129582643509
+0.05088646337389946 0.0005661551840603352 0.00039123010355979204 0.021084178239107132 0.04122326895594597
+0.059640754014253616 0.00041626859456300735 0.0007189905736595392 0.018827196210622787 0.07679850608110428
+0.04194796457886696 9.724109986564144e-05 0.0005159295978955925 0.02652631141245365 0.049422651529312134
+0.04842798039317131 0.0007827847148291767 0.001038213144056499 0.028607899323105812 0.0703549012541771
+0.0404585525393486 0.0006707462598569691 0.0006546536460518837 0.02194291353225708 0.03864453360438347
+0.037516556680202484 9.716466593090445e-05 0.0003420012362767011 0.01723761111497879 0.026326078921556473
+0.05147685110569 0.0007180980173870921 0.0007751449593342841 0.018244991078972816 0.03881964087486267
+0.03398619592189789 0.0006666429690085351 0.0008951362106017768 0.014290235936641693 0.01270695123821497
+0.04995296895503998 9.001532453112304e-05 0.0003333288768772036 0.015594011172652245 0.062007348984479904
+0.054084230214357376 0.0005964112933725119 0.0003123401547782123 0.01636691950261593 0.08993502706289291
+0.038572993129491806 0.0006240400834940374 0.0003473498800303787 0.025465888902544975 0.04568615183234215
+0.05865313112735748 0.00019865580543410033 0.00033158567384816706 0.021076155826449394 0.10332522541284561
+0.03892526775598526 9.549756214255467e-05 0.00025634883786551654 0.015280770137906075 0.022461438551545143
+0.054541148245334625 0.0003497689904179424 0.00019891199190169573 0.011720112524926662 0.05859297141432762
+0.05019726976752281 0.00040438008727505803 0.00019390527450013906 0.012678771279752254 0.02572837471961975
+0.040876444429159164 0.00016685279842931777 0.00019301944121252745 0.013552000746130943 0.0346338152885437
+0.03654496371746063 7.953925523906946e-05 0.00024021716671995819 0.016982462257146835 0.03953758254647255
+0.0511263906955719 0.00025266085867770016 0.0004366478242445737 0.019150445237755775 0.07734522223472595
+0.053082142025232315 0.00026215973775833845 0.0006898969877511263 0.02450012043118477 0.066392682492733
+0.04962801933288574 9.902760211843997e-05 0.0004631518677342683 0.02053784765303135 0.06575203686952591
+0.04448424279689789 0.00010817767906701192 0.0005188685026951134 0.021573420614004135 0.04075705632567406
+0.04941311478614807 0.00016797217540442944 0.00028101252974011004 0.019009387120604515 0.049012258648872375
+0.04874536767601967 0.0002569358330219984 0.00040716995135881007 0.015647554770112038 0.04940437898039818
+0.06316929310560226 0.0002482692070771009 0.0004966053529642522 0.03243101015686989 0.08086176961660385
+0.04614236205816269 8.65654947119765e-05 0.0005416623316705227 0.0178066436201334 0.04626196250319481
+0.05579838901758194 0.0003881159645970911 0.0005007837316952646 0.017776042222976685 0.07006525993347168
+0.03384663909673691 0.0003930515085812658 0.00029750828980468214 0.01766224019229412 0.025946540758013725
+0.04775070399045944 7.958710193634033e-05 0.0002462641859892756 0.013233838602900505 0.0567624494433403
+0.059870459139347076 0.00017593579832464457 0.028047936037182808 0.016222402453422546 0.048593807965517044
+0.03293168544769287 0.00013249757466837764 0.00026228983188048005 0.01652970165014267 0.0273911003023386
+0.06281153857707977 7.251116767292842e-05 0.0003812953655142337 0.014944348484277725 0.09968111664056778
+0.04083893820643425 5.7740478951018304e-05 0.00029759175959043205 0.017881356179714203 0.0330098494887352
+0.035316623747348785 7.118929352145642e-05 0.0002771184081211686 0.01568668708205223 0.020386632531881332
+0.04666747525334358 9.987732482841238e-05 0.00023920915555208921 0.01621692068874836 0.057850319892168045
+0.04157688841223717 0.00012210964632686228 0.0003968172532040626 0.018714647740125656 0.047498732805252075
+0.04781300947070122 7.961975643411279e-05 0.00032045290572568774 0.030364029109477997 0.05658358708024025
+0.03958440572023392 0.0002859729283954948 0.0007960073999129236 0.03249582648277283 0.029483409598469734
+0.058700721710920334 0.000322048959787935 0.0006996268639340997 0.020931435748934746 0.055028900504112244
+0.05233165994286537 0.00010090599971590564 0.0004046002868562937 0.01774951070547104 0.04135189577937126
+0.047308892011642456 0.0006707776337862015 0.0008819426875561476 0.02488882653415203 0.05957932025194168
+0.04764788597822189 0.00032516103237867355 0.0006215593311935663 0.028743447735905647 0.0524635836482048
+0.04958777129650116 0.00021473861124832183 0.000463977426989004 0.020655637606978416 0.04749278724193573
+0.04646501690149307 0.00048609593068249524 0.0007171914330683649 0.016742222011089325 0.06396681815385818
+0.04618650674819946 9.679852519184351e-05 0.0002489935723133385 0.02156897820532322 0.050837643444538116
+0.04460864141583443 0.00033449140028096735 0.0007439503096975386 0.012160422280430794 0.04005606472492218
+0.052232179790735245 0.00025066558737307787 0.0007868373650126159 0.0167509987950325 0.06872778385877609
+0.049980536103248596 0.00016557866183575243 0.000243474860326387 0.015186182223260403 0.06435464322566986
+0.041628219187259674 0.0005765896639786661 0.00032918743090704083 0.015976741909980774 0.02041781321167946
+0.04525972530245781 0.00030174560379236937 0.00044638325925916433 0.011691144667565823 0.041540782898664474
+0.03933240845799446 0.0001544550759717822 0.0003700086090248078 0.01620732806622982 0.03921680524945259
+0.04034270718693733 0.00037739609251730144 0.0004380463215056807 0.018609793856739998 0.03116658143699169
+0.046786583960056305 0.00015079417789820582 0.00039819764788262546 0.017392747104167938 0.053965069353580475
+0.04151429608464241 0.000120206517749466 0.0008321929490193725 0.018047332763671875 0.0527830570936203
+0.04950729012489319 9.892550588119775e-05 0.0008655494311824441 0.01889077201485634 0.05368237569928169
+0.04460406303405762 0.000192816267372109 0.0005103360745124519 0.020696694031357765 0.052027154713869095
+0.05590125173330307 0.0003347214369568974 0.0005159639986231923 0.01779898628592491 0.03646703436970711
+0.03654797375202179 0.00016743085870984942 0.0005841145175509155 0.011631667613983154 0.02561810053884983
+0.030749265104532242 0.00010412967822048813 0.0006291173049248755 0.015655264258384705 0.028310788795351982
+0.04612148553133011 0.0003088669618591666 0.0005648363730870187 0.026773054152727127 0.061461322009563446
+0.05061928927898407 0.000799689325504005 0.0005233484553173184 0.027945592999458313 0.06865710765123367
+0.04542887955904007 0.0006065754569135606 0.0004990017623640597 0.02116755023598671 0.058410800993442535
+0.03522033244371414 8.317542233271524e-05 0.00026876688934862614 0.013937200419604778 0.02354234643280506
+0.050521232187747955 0.0011563189327716827 0.00044565877760760486 0.01782005839049816 0.05947285145521164
+0.04458218067884445 0.001289963023737073 0.0006195165915414691 0.017410043627023697 0.02291880175471306
+0.04621896147727966 6.94659975124523e-05 0.0004059445345774293 0.01647331193089485 0.0335136316716671
+0.04779564589262009 0.0007659331313334405 0.0008025292190723121 0.01600305177271366 0.06164766103029251
+0.038686685264110565 0.0002676500007510185 0.0002976803225465119 0.01704459637403488 0.05094914883375168
+0.0552300363779068 0.00012585641525220126 0.0009234954486601055 0.023752311244606972 0.06954559683799744
+0.04730141535401344 9.956991561921313e-05 0.0006183803197927773 0.030715100467205048 0.04986957460641861
+0.04627535864710808 0.0002388934517512098 0.00030881576822139323 0.012625313363969326 0.04357065632939339
+0.049556903541088104 0.0001842294877860695 0.0003946185752283782 0.014527792111039162 0.045135244727134705
+0.04872230440378189 0.0001341294264420867 0.00030699640046805143 0.019820652902126312 0.05256306380033493
+0.03516384959220886 0.00031046936055645347 0.00022174492187332362 0.01712595671415329 0.02301168441772461
+0.03334477171301842 0.00019901309860870242 0.00023955156211741269 0.013155665248632431 0.02965151146054268
+0.04597854986786842 0.00016205436259042472 0.0002527103351894766 0.030966609716415405 0.05537394434213638
+0.042876459658145905 0.0001380488247377798 0.0005940871778875589 0.022944267839193344 0.04782228544354439
+0.03496925160288811 0.00023172018700279295 0.00031550758285447955 0.023992031812667847 0.029499435797333717
+0.037318259477615356 0.0009093151893466711 0.0005672675906680524 0.017920570448040962 0.03872574865818024
+0.05055774003267288 0.000364943640306592 0.0003787277964875102 0.01449122279882431 0.06782670319080353
+0.055290043354034424 0.0002976049145217985 0.0006541520706377923 0.02298717200756073 0.06064459681510925
+0.039281442761421204 0.0004989005392417312 0.0007457975880242884 0.02650555409491062 0.03530228137969971
+0.048564210534095764 8.069612522376701e-05 0.000617660058196634 0.02047317661345005 0.04752914607524872
+0.046406298875808716 0.000363981380360201 0.0015743181575089693 0.014203292317688465 0.049828846007585526
+0.03461800515651703 0.00010353914694860578 0.0003789479669649154 0.011255599558353424 0.04024212434887886
+0.049111030995845795 0.0009056192939169705 0.0008193701505661011 0.019594941288232803 0.03137121722102165
+0.04615755379199982 0.0003353313950356096 0.0005238708690740168 0.018663372844457626 0.029711982235312462
+0.05048546940088272 0.00028400978771969676 0.0003152509161736816 0.016380395740270615 0.07916445285081863
+0.03870992362499237 0.0005829301080666482 0.0004671995993703604 0.017250152304768562 0.04102185741066933
+0.03878086060285568 0.00012872961815446615 0.00039164372719824314 0.023269573226571083 0.03146281838417053
+0.03924449160695076 7.388847734546289e-05 0.00041692121885716915 0.01622234471142292 0.038220565766096115
+0.04177910089492798 0.00010152484173886478 0.00030633845017291605 0.01079540979117155 0.04794461280107498
+0.040280114859342575 0.00010247564205201343 0.0002463298151269555 0.014305326156318188 0.031488921493291855
+0.04740426689386368 9.63080019573681e-05 0.0003474602708593011 0.018929079174995422 0.04909713938832283
+0.03669304400682449 0.00010589996963972226 0.0003679616202134639 0.021311309188604355 0.04305986687541008
+0.044141799211502075 0.00018902415467891842 0.0005995373940095305 0.013605231419205666 0.05940886586904526
+0.04463652893900871 5.980227069812827e-05 0.00042136473348364234 0.03680918738245964 0.04676062986254692
+0.03953846916556358 0.00032040386577136815 0.00023024153779260814 0.014928906224668026 0.03192439675331116
+0.03728476166725159 0.00032210673089139163 0.0003129890828859061 0.014780182391405106 0.036972444504499435
+0.042547523975372314 6.750904867658392e-05 0.0002854867198038846 0.012061984278261662 0.02047760784626007
+0.04891997203230858 0.00019722370780073106 0.000294944184133783 0.014781424775719643 0.031249502673745155
+0.055410075932741165 0.00023331843840423971 0.00030244587105698884 0.014322950504720211 0.05314657464623451
+0.04259981960058212 0.0001405656075803563 0.0004256684915162623 0.01973658986389637 0.052892062813043594
+0.045194514095783234 0.0001348409423371777 0.0009532158146612346 0.02049841545522213 0.0487748384475708
+0.03791509568691254 0.0001392815902363509 0.00038753615808673203 0.01725972630083561 0.04219908267259598
+0.04006143659353256 0.00033596300636418164 0.000533512094989419 0.012907123193144798 0.03442351147532463
+0.04994327574968338 0.00012632380821742117 0.00027214648434892297 0.021317649632692337 0.07152488827705383
+0.045893795788288116 0.0005115030216984451 0.0004324868496041745 0.023006141185760498 0.06574906408786774
+0.04358575493097305 0.00024052252410911024 0.0003518353041727096 0.019989585503935814 0.029775775969028473
+0.03657427430152893 0.00016180702368728817 0.0007090282742865384 0.011792084202170372 0.04619480296969414
+0.047486111521720886 8.983511361293495e-05 0.00036228031967766583 0.012643391266465187 0.05718066170811653
+0.04737621545791626 0.0002209805097663775 0.0003910962550435215 0.026458969339728355 0.02474857121706009
+0.04016261547803879 0.00013438994938042015 0.0003912587999366224 0.014252165332436562 0.03874240070581436
+0.04518686234951019 0.00017189154459629208 0.00015555098070763052 0.01069786585867405 0.0570724681019783
+0.03356561064720154 0.00014420576917473227 0.000296771468129009 0.015569651499390602 0.029491089284420013
+0.031742654740810394 0.00011781977082137018 0.0006520590977743268 0.02464296668767929 0.026552490890026093
+0.03821422532200813 0.00024201998894568533 0.0002593390236143023 0.022473298013210297 0.02663777396082878
+0.04186077043414116 0.0002681734331417829 0.0005999885615892708 0.015452375635504723 0.04267299175262451
+0.03568820655345917 0.00016703586152289063 0.00032539592939428985 0.013816284015774727 0.02411210909485817
+0.04350268095731735 0.0010389925446361303 0.0006662337691523135 0.020595090463757515 0.04431942105293274
+0.0397929847240448 0.0003591422864701599 0.0003854332899209112 0.024990366771817207 0.04159562662243843
+0.05321178585290909 0.000397214840631932 0.0011548862094059587 0.016511332243680954 0.08148755878210068
+0.050657376646995544 0.000575487210880965 0.0016676199156790972 0.015696343034505844 0.052813947200775146
+0.05419951677322388 0.00015634768351446837 0.0002326418034499511 0.02472950704395771 0.04948841407895088
+0.0517861433327198 0.0011182993184775114 0.0010560400551185012 0.013367327861487865 0.03279925510287285
+0.0369216687977314 0.00033012175117619336 0.0006857919506728649 0.014970271848142147 0.0365133099257946
+0.028427444398403168 0.0004305476904846728 0.00033532505040057003 0.01627415232360363 0.02154889889061451
+0.04401422291994095 0.0010886642849072814 0.0006173203000798821 0.022722143679857254 0.05378362536430359
+0.03889453038573265 0.000435082329204306 0.0002737250761128962 0.033678870648145676 0.03451455011963844
+0.04231877252459526 0.00013147108256816864 0.0002997481788042933 0.015899691730737686 0.05817969888448715
+0.03656962513923645 0.0003895616973750293 0.00018140663451049477 0.014249729923903942 0.022842390462756157
+0.04222848266363144 0.0008164001628756523 0.00017463442054577172 0.011460380628705025 0.024771517142653465
+0.04242268204689026 0.00038271566154435277 0.0002209916856372729 0.015378016978502274 0.008761425502598286
+0.05706740915775299 0.00011517589155118912 0.0002529794874135405 0.01494641788303852 0.057561881840229034
+0.04184588044881821 0.00027361290995031595 0.00017930599278770387 0.012290685437619686 0.05720818415284157
+0.06459963321685791 0.00013618830416817218 0.00039595089037902653 0.020676864311099052 0.12582409381866455
+0.04357337951660156 0.0001117309438996017 0.000590067938901484 0.02965870313346386 0.05850304290652275
+0.04247388243675232 0.0003889034269377589 0.0003462046734057367 0.01751481555402279 0.05915118753910065
+0.050856880843639374 0.0002561220608185977 0.0009708622237667441 0.014031533151865005 0.049481626600027084
+0.061596814543008804 0.0004717206466011703 0.0001821186306187883 0.012798607349395752 0.03548272326588631
+0.04193504899740219 0.001890572370029986 0.0007777176215313375 0.01523631252348423 0.026375146582722664
+0.036368388682603836 0.0004859486944042146 0.0004255660460330546 0.017955316230654716 0.037270687520504
+0.04236491769552231 0.0006288814474828541 0.000694558781106025 0.01932057924568653 0.06107475608587265
+0.04042920470237732 0.0014377094339579344 0.0013325802283361554 0.020232927054166794 0.04962662607431412
+0.03739534318447113 0.00028603486134670675 0.0007303840829990804 0.029497874900698662 0.05570673197507858
+0.050255514681339264 0.00024805148132145405 0.0007079236092977226 0.02650279738008976 0.06579962372779846
+0.03051788918673992 0.0005317136528901756 0.0005399896763265133 0.011475583538413048 0.020279772579669952
+0.04224223643541336 0.0002717618481256068 0.00035048884456045926 0.013135953806340694 0.021178346127271652
+0.04159165918827057 9.666655387263745e-05 0.0002907991874963045 0.017689716070890427 0.03436784818768501
+0.06002529338002205 0.00020419327483978122 0.0002742809592746198 0.0174631979316473 0.06493135541677475
+0.03898200765252113 0.00014717638259753585 0.00023353690630756319 0.01273210346698761 0.038070883601903915
+0.037579894065856934 9.14642951101996e-05 0.0003860057913698256 0.012506847269833088 0.032532788813114166
+0.03690694272518158 0.00010370163363404572 0.00047063210513442755 0.013170979917049408 0.04768241569399834
+0.048418186604976654 0.00010651195771060884 0.0002547872718423605 0.03862827271223068 0.06515828520059586
+0.04170041158795357 0.00023402216902468354 0.0004306159680709243 0.017149873077869415 0.03644504398107529
+0.037064410746097565 0.00016582466196268797 0.0008213077671825886 0.009691311977803707 0.036667145788669586
+0.033474892377853394 0.00030441401759162545 0.0002767541736830026 0.009422605857253075 0.03087092936038971
+0.0396532341837883 0.001067486358806491 0.000410425360314548 0.01090981811285019 0.043268833309412
+0.04794224351644516 0.00033676126622594893 0.00030092254746705294 0.020030589774250984 0.0410042330622673
+0.04640026390552521 0.0004132547474000603 0.0007294168462976813 0.01918785832822323 0.05297602713108063
+0.043879732489585876 0.0009494617697782815 0.0010463185608386993 0.017364265397191048 0.05579464137554169
+0.04321271926164627 0.00010689205373637378 0.0005678262677974999 0.025153951719403267 0.04856787621974945
+0.053650274872779846 0.0003703027032315731 0.0007212780183181167 0.023190196603536606 0.06765040755271912
+0.034095827490091324 0.00025348825147375464 0.0003458250139374286 0.016346484422683716 0.024216128513216972
+0.05452197045087814 6.539409514516592e-05 0.00033477554097771645 0.014762519858777523 0.02264082431793213
+0.03600785881280899 6.976806616876274e-05 0.00019095384050160646 0.013730517588555813 0.031931765377521515
+0.034286417067050934 0.0001232129434356466 0.0005484448047354817 0.020704839378595352 0.03225399926304817
+0.04776603356003761 0.00010527107951929793 0.0005785279208794236 0.01857561059296131 0.07375045120716095
+0.048520635813474655 0.0003312512126285583 0.00043536126031540334 0.021707117557525635 0.05108802765607834
+0.038349926471710205 0.0005654720007441938 0.00043301130062900484 0.015365292318165302 0.03715357929468155
+0.051425959914922714 9.748549200594425e-05 0.0002060667466139421 0.016916519030928612 0.05984867364168167
+0.045690856873989105 0.00033890397753566504 0.00043243265827186406 0.0149710513651371 0.039009444415569305
+0.03947645053267479 0.0003784378059208393 0.00027165914070792496 0.014987105503678322 0.04773218557238579
+0.047622375190258026 6.277146894717589e-05 0.0004542810202110559 0.014381442219018936 0.06885947287082672
+0.027121631428599358 0.00011697092850226909 0.000608585134614259 0.029627101495862007 0.01745479740202427
+0.03716609999537468 7.512830052291974e-05 0.0004113356990274042 0.02227230928838253 0.016462454572319984
+0.042951248586177826 0.0001401963527314365 0.00022671320766676217 0.015667853876948357 0.04845167696475983
+0.044423386454582214 0.00031116121681407094 0.0003777859383262694 0.017297720536589622 0.051457729190588
+0.0527816116809845 9.910454537020996e-05 0.0003079073503613472 0.018441783264279366 0.0713198110461235
+0.032259512692689896 0.0004239249392412603 0.00024910556385293603 0.016242844983935356 0.01545530091971159
+0.03829669952392578 0.000531988509465009 0.00036396057112142444 0.015002320520579815 0.04589622840285301
+0.03961847722530365 9.751849574968219e-05 0.00024121464230120182 0.011583779938519001 0.0280532818287611
+0.03790324181318283 0.0006219971110112965 0.0010053599253296852 0.011925852857530117 0.040490634739398956
+0.04836473986506462 0.0002474669017829001 0.0005277530290186405 0.015302061103284359 0.05689449980854988
+0.046353697776794434 0.00030236394377425313 0.0004150049644522369 0.013100716285407543 0.04316987097263336
+0.04346085339784622 0.0005881632096134126 0.0005998255801387131 0.016004959121346474 0.039761364459991455
+0.03766922280192375 0.00012389550101943314 0.0005672244005836546 0.017504289746284485 0.04982800781726837
+0.04905422404408455 0.0002754140004981309 0.0006346802692860365 0.02336544543504715 0.08000436425209045
+0.030809763818979263 0.00018916858243755996 0.0006315811770036817 0.01755371503531933 0.020426806062459946
+0.054070692509412766 0.00011063079000450671 0.000745335768442601 0.013817325234413147 0.08030752092599869
+0.03475947305560112 0.00011435744090704247 0.0007070715073496103 0.017524568364024162 0.02745189145207405
+0.047868695110082626 0.00014560279669240117 0.0003278460935689509 0.014474084600806236 0.043489597737789154
+0.044242583215236664 0.0003339141549076885 0.0003918500733561814 0.014596573077142239 0.026955336332321167
+0.04920293763279915 0.0001925042743096128 0.0003641583607532084 0.011898106895387173 0.032350655645132065
+0.04721513018012047 0.00012722000246867537 0.0002399489312665537 0.012642228975892067 0.04728970676660538
+0.03342341259121895 0.0001902619405882433 0.0005097182001918554 0.013482130132615566 0.017483536154031754
+0.038155656307935715 0.00016501633217558265 0.00090141681721434 0.023824170231819153 0.04838263615965843
+0.04887780174612999 0.0002901703701354563 0.0006319936364889145 0.02680441178381443 0.07616760581731796
+0.03848357871174812 0.000543581263627857 0.00043267058208584785 0.02061648480594158 0.03270060941576958
+0.043486542999744415 0.00017172133084386587 0.0005533384974114597 0.016828319057822227 0.0525515191257
+0.03753150999546051 0.0004478935443330556 0.00042876467341557145 0.015882210806012154 0.03019683249294758
+0.03338048234581947 0.0012496723793447018 0.0002451865002512932 0.019128626212477684 0.01008888240903616
+0.037311673164367676 0.0005869517335668206 0.0002586250484455377 0.05068931356072426 0.016296738758683205
+0.04712647944688797 0.0001327316858805716 0.0006130874971859157 0.01463096309453249 0.05710967257618904
+0.040580347180366516 0.0002308625407749787 0.00045731652062386274 0.017669398337602615 0.038109369575977325
+0.04838092625141144 0.0003008131461683661 0.000405290222261101 0.024710604920983315 0.06088049337267876
+0.043596312403678894 0.00021516949345823377 0.0003723435802385211 0.020557953044772148 0.052382081747055054
+0.034699827432632446 0.00024630752159282565 0.0003012387896887958 0.011588292196393013 0.02516244538128376
+0.04486110061407089 0.00014069447934161872 0.0002137018454959616 0.01440915372222662 0.04535683989524841
+0.04354090988636017 0.0005877528456039727 0.00029637134866788983 0.013659253716468811 0.020231008529663086
+0.040029723197221756 0.0002550541248638183 0.00021958073193673044 0.01542797964066267 0.048187579959630966
+0.03319445997476578 0.00011082493438152596 0.0003462627646513283 0.01628042384982109 0.03540778532624245
+0.04072754830121994 0.00018900186114478856 0.00047897823969833553 0.019991273060441017 0.05090874060988426
+0.04800720885396004 0.0001509337016614154 0.0004018845793325454 0.024639826267957687 0.06426767259836197
+0.034299131482839584 0.00013378165022004396 0.0002366479893680662 0.017920121550559998 0.03318191319704056
+0.037873879075050354 0.0001289352949243039 0.00025962531799450517 0.012933562509715557 0.034123536199331284
+0.04890073090791702 0.00020090054022148252 0.000414277077652514 0.01095808856189251 0.022720905020833015
+0.03848021849989891 0.0003052517422474921 0.0005095459637232125 0.015824100002646446 0.03472979739308357
+0.03773253411054611 0.00018269172869622707 0.0003746073052752763 0.02039659582078457 0.04142913222312927
+0.03718966618180275 0.0006027363706380129 0.00036014922079630196 0.01749032363295555 0.052743956446647644
+0.03545248880982399 0.0003789575130213052 0.0005200337618589401 0.015273517929017544 0.03403310477733612
+0.04481104016304016 0.00017113902140408754 0.00037843186873942614 0.018769826740026474 0.04154214262962341
+0.03588680922985077 0.0004683040897361934 0.00027192573179490864 0.014265353791415691 0.024992497637867928
+0.0478799045085907 0.00033070065546780825 0.00045746599789708853 0.013071169145405293 0.03288145363330841
+0.04576032608747482 0.0001736870763124898 0.0022797705605626106 0.01788930781185627 0.05812292918562889
+0.04772456735372543 0.00011777938198065385 0.0004353740077931434 0.017314478754997253 0.07223832607269287
+0.0336783193051815 0.00010283775918651372 0.00039747232221998274 0.021511610597372055 0.03634041175246239
+0.03979257121682167 0.00014942897541914135 0.0002289754629600793 0.011775920167565346 0.046335071325302124
+0.032260969281196594 0.00014748443209100515 0.00021033352822996676 0.015436374582350254 0.01898202672600746
+0.031113401055336 0.0001007328464766033 0.0004761680611409247 0.012541680596768856 0.028426919132471085
+0.042036764323711395 0.00010582723916741088 0.0006541821057908237 0.01519523374736309 0.03084932267665863
+0.038071561604738235 0.00015432755753863603 0.0002661102917045355 0.022105339914560318 0.039756812155246735
+0.06180395185947418 0.00032070555607788265 0.0002309754490852356 0.015944018959999084 0.1087975949048996
+0.039181649684906006 0.00014777056640014052 0.0003437065752223134 0.020845573395490646 0.03726835921406746
+0.04183484613895416 0.00010718624980654567 0.0003113493730779737 0.015958299860358238 0.03239802271127701
+0.03494773060083389 0.0002948210167232901 0.00036584961344487965 0.01102394051849842 0.028140276670455933
+0.0301162488758564 0.0005188359646126628 0.00019066805543843657 0.011253000237047672 0.01918441243469715
+0.034364841878414154 0.00037820026045665145 0.0002453979104757309 0.015070540830492973 0.016489826142787933
+0.02908487617969513 0.00010017139720730484 0.0002961065911222249 0.017705079168081284 0.017626937478780746
+0.04329628497362137 0.0005281681660562754 0.000385093066142872 0.017293518409132957 0.05694231018424034
+0.057674944400787354 0.0005612653912976384 0.0003672384191304445 0.02604471519589424 0.08557464927434921
+0.031157420948147774 9.798325481824577e-05 0.0003091402177233249 0.01797928847372532 0.023115914314985275
+0.030732423067092896 0.00010242475400445983 0.0003008343919645995 0.013334639370441437 0.030690129846334457
+0.03619895875453949 0.0002942577993962914 0.00020509902969934046 0.013362089172005653 0.02697836048901081
+0.04435384273529053 0.0002847858122549951 0.000301755266264081 0.01165863685309887 0.05124155059456825
+0.043890580534935 6.071206007618457e-05 0.00023245529155246913 0.013748666271567345 0.022887932136654854
+0.0378761887550354 0.000292961165541783 0.00018306328274775296 0.014567560516297817 0.04484223201870918
+0.04051099717617035 0.00030826518195681274 0.0003045146004296839 0.012068461626768112 0.03387744724750519
+0.04128190875053406 8.017060463316739e-05 0.00033758379868231714 0.016138635575771332 0.04101767763495445
+0.036082495003938675 0.0001365714124403894 0.0006281495443545282 0.026332871988415718 0.04832968860864639
+0.04134272783994675 0.00011268419621046633 0.00032933251350186765 0.0213032066822052 0.05841119959950447
+0.04061233252286911 0.00031098141334950924 0.001146190450526774 0.012674292549490929 0.03139827400445938
+0.03826653212308884 6.154119182610884e-05 0.0007549045258201659 0.011816013604402542 0.03568975254893303
+0.03350222110748291 0.0010984087130054832 0.00018247334810439497 0.0145728699862957 0.01054803654551506
+0.04283124953508377 0.000982927973382175 0.0003002918674610555 0.011947574093937874 0.04306334629654884
+0.041881099343299866 5.376888657337986e-05 0.00024251019931398332 0.011023646220564842 0.04120080918073654
+0.04461502283811569 0.0005507972091436386 0.00033542001619935036 0.012251484207808971 0.054294660687446594
+0.04003491997718811 0.00044841787894256413 0.0002523764851503074 0.013691752217710018 0.05331987142562866
+0.03640422970056534 0.000114661299448926 0.0003974464489147067 0.015080628916621208 0.04278307408094406
+0.03760319948196411 7.481749344151467e-05 0.00040804612217471004 0.011288939043879509 0.049227070063352585
+0.03606434911489487 0.00011150459613418207 0.00042103687883354723 0.01268589124083519 0.03427613899111748
+0.04261608421802521 6.47433553240262e-05 0.0003321597178000957 0.014156234450638294 0.04518156126141548
+0.03922431170940399 7.690703205298632e-05 0.0003063617623411119 0.01943134143948555 0.04036850854754448
+0.03918202593922615 9.840486745815724e-05 0.0004931915318593383 0.01863085851073265 0.044144779443740845
+0.03534528240561485 0.00010067146649817005 0.0002728060935623944 0.014816651120781898 0.025753764435648918
+0.035085923969745636 0.0003972369013354182 0.0003308208251837641 0.012662319466471672 0.02545648254454136
+0.03309818357229233 0.0001584364945301786 0.00039825227577239275 0.015501484274864197 0.015431474894285202
+0.044005945324897766 0.00021890064817853272 0.0004974746843799949 0.019546952098608017 0.05763867497444153
+0.03666162118315697 0.0005298641044646502 0.00047438591718673706 0.021934937685728073 0.038232047110795975
+0.04575826972723007 0.0002189634251408279 0.0003573470457922667 0.020237386226654053 0.06097244471311569
+0.04716690257191658 0.00015674026508349925 0.0003279077063780278 0.013312785886228085 0.059835441410541534
+0.053217314183712006 0.0005149930366314948 0.00023420891375280917 0.012671126052737236 0.053818799555301666
+0.03662043809890747 0.00015919496945571154 0.00019702070858329535 0.013500794768333435 0.03552998974919319
+0.04593996703624725 8.030498429434374e-05 0.0002926443121396005 0.013717164285480976 0.055002935230731964
+0.02358497679233551 0.0001039390845107846 0.0002952592622023076 0.012930944561958313 0.015348896384239197
+0.038346726447343826 7.170259777922183e-05 0.00026426010299474 0.015116909518837929 0.02696080319583416
+0.040992047637701035 0.00010082894004881382 0.0005936106899753213 0.014803619123995304 0.05090982839465141
+0.04079563915729523 5.726216113544069e-05 0.00044304263428784907 0.013177181594073772 0.05041176453232765
+0.032143235206604004 0.00029883740353398025 0.00020365201635286212 0.015400602482259274 0.039378631860017776
+0.03852139040827751 0.00028225043206475675 0.0002693807182367891 0.022976089268922806 0.028070896863937378
+0.04290741682052612 5.700478141079657e-05 0.000416918279370293 0.012985439039766788 0.023020915687084198
+0.039683058857917786 0.0002852241159416735 0.0002550007193349302 0.010162887163460255 0.04679083451628685
+0.03189413994550705 0.0005459712701849639 0.0004441451164893806 0.015354132279753685 0.03011264093220234
+0.03165911138057709 7.266945613082498e-05 0.00034621969098225236 0.02399607188999653 0.009092112071812153
+0.039088573306798935 0.0006165591184981167 0.0016887311358004808 0.014484121464192867 0.036796923726797104
+0.04122287034988403 0.0002157391281798482 0.0012031561927869916 0.026874743402004242 0.04872762784361839
+0.042893391102552414 0.0006412427173927426 0.0005668283556587994 0.018133040517568588 0.04924541339278221
+0.03940335661172867 0.0008062769193202257 0.000865756708662957 0.01924906298518181 0.042071327567100525
+0.038631293922662735 0.0001185203727800399 0.0002859005471691489 0.016308359801769257 0.028764992952346802
+0.037027597427368164 0.0008055799989961088 0.000825054245069623 0.014154442586004734 0.03303854167461395
+0.04455241188406944 6.281716923695058e-05 0.00037695339415222406 0.012849261984229088 0.05287453159689903
+0.0326557382941246 0.0006246584234759212 0.0024512065574526787 0.021085355430841446 0.03035835176706314
+0.035908088088035583 7.270005880855024e-05 0.00077816128032282 0.02158271335065365 0.03980375826358795
+0.047019150108098984 0.001468671951442957 0.000787876604590565 0.02198123000562191 0.035202741622924805
+0.039554283022880554 0.0005071746418252587 0.000485696509713307 0.014965745620429516 0.04921843856573105
+0.03732459992170334 0.0006708041182719171 0.000653947121463716 0.020296700298786163 0.026238437741994858
+0.026861470192670822 0.0011727426899597049 0.0007304633618332446 0.018187204375863075 0.008373814634978771
+0.038887329399585724 8.328249532496557e-05 0.0004244180745445192 0.015867561101913452 0.039612870663404465
+0.03910864144563675 0.0004344603803474456 0.001207797322422266 0.019356228411197662 0.05763198435306549
+0.027686145156621933 9.326910367235541e-05 0.00047184518189169466 0.025267869234085083 0.014732690528035164
+0.03655597195029259 0.00013277442485559732 0.00021226181706879288 0.017335938289761543 0.0316474586725235
+0.046321071684360504 0.0001856183516792953 0.00015352218179032207 0.014773832634091377 0.05058871954679489
+0.045957401394844055 0.00016299978597089648 0.0001650407793931663 0.016884470358490944 0.04838259890675545
+0.031184915453195572 6.0765079979319125e-05 0.0003507820947561413 0.019574064761400223 0.02829292230308056
+0.05252864211797714 0.00024530317750759423 0.00019876255828421563 0.02186805009841919 0.04372232034802437
+0.03411879763007164 0.0003377409011591226 0.0002743960649240762 0.01241423562169075 0.01924932189285755
+0.04088502377271652 0.0001390402758261189 0.00033404381247237325 0.015726188197731972 0.038395460695028305
+0.037033140659332275 0.0009072211105376482 0.0021137245930731297 0.0181281715631485 0.034608565270900726
+0.030525026842951775 0.00012756054638884962 0.0012204161612316966 0.02345670945942402 0.02907000668346882
+0.049645502120256424 0.0018182828789576888 0.0009232256561517715 0.018862387165427208 0.08756531774997711
+0.048653922975063324 0.0015233919257298112 0.001451849122531712 0.015733787789940834 0.07468466460704803
+0.05481558293104172 0.00025350204668939114 0.0002752640575636178 0.013818728737533092 0.055305976420640945
+0.04855652153491974 0.002455405192449689 0.0007395308930426836 0.011782864108681679 0.03626453876495361
+0.04392180219292641 0.0012500612065196037 0.00043679377995431423 0.010276642628014088 0.03725406154990196
+0.03308556228876114 0.00013811764074489474 0.00033595814602449536 0.015830570831894875 0.022059142589569092
+0.04305988922715187 0.000914845964871347 0.0008187991916202009 0.023827310651540756 0.036027684807777405
+0.043123818933963776 0.0007314842659980059 0.0005917611997574568 0.019697710871696472 0.0565289631485939
+0.04999302327632904 0.00016216908989008516 0.00047747950884513557 0.023485872894525528 0.07581055164337158
+0.04524645581841469 0.00031871136161498725 0.0008570287609472871 0.026460964232683182 0.05487258732318878
+0.034056395292282104 7.578904478577897e-05 0.00023897104256320745 0.018025662750005722 0.038768280297517776
+0.04887538403272629 0.00026039310614578426 0.0008726685773581266 0.015687504783272743 0.030098870396614075
+0.03424772620201111 6.298570951912552e-05 0.0005655326531268656 0.014325260184705257 0.014035274274647236
+0.035518646240234375 0.0009895984549075365 0.0003524295170791447 0.011844303458929062 0.015469927340745926
+0.038106389343738556 0.0011935470392927527 0.0014814173337072134 0.011855296790599823 0.013636051677167416
+0.040866635739803314 0.00011766892566811293 0.0005399491055868566 0.0156412236392498 0.056336358189582825
+0.04332992807030678 0.0025118901394307613 0.0004917587502859533 0.022660480812191963 0.06543207168579102
+0.0432548001408577 0.002168146427720785 0.0006487348000518978 0.02038046158850193 0.06373853981494904
+0.03934987634420395 5.670416067005135e-05 0.00030800863169133663 0.01690548285841942 0.04559861496090889
+0.04378846287727356 0.0011691092513501644 0.0002555679529905319 0.01002573873847723 0.03348907455801964
+0.03393714874982834 0.0010276427492499352 0.00020506535656750202 0.011313275434076786 0.02061721868813038
+0.05611640214920044 0.00023237841378431767 0.0006098479498177767 0.012407908216118813 0.05098983645439148
+0.043595168739557266 8.009849989321083e-05 0.00027443229919299483 0.015029777772724628 0.017146864905953407
+0.035648196935653687 6.269797449931502e-05 0.0005420285160653293 0.02358481101691723 0.03094661608338356
+0.03624863177537918 0.00025389346410520375 0.0005838297074660659 0.021285787224769592 0.041988957673311234
+0.0451149046421051 0.0005108311306685209 0.0007575162453576922 0.018723972141742706 0.07425503432750702
+0.04116831719875336 0.00022330964566208422 0.0006715443450957537 0.025146275758743286 0.05362781509757042
+0.029936902225017548 0.00014439773804042488 0.0007878722390159965 0.02741323970258236 0.022599879652261734
+0.04760969430208206 0.00011350985732860863 0.00043987829121761024 0.014360723085701466 0.02648620493710041
+0.040393952280282974 6.88554355292581e-05 0.0006091141258366406 0.011721385642886162 0.02587391622364521
+0.04262467101216316 0.0010817993897944689 0.0002452904882375151 0.01329643465578556 0.01732979528605938
+0.03459760919213295 0.0009159284527413547 0.0004941332736052573 0.01625187136232853 0.01974625699222088
+0.03286433964967728 0.0001291776861762628 0.0003396294196136296 0.018997790291905403 0.04268352687358856
+0.03536933660507202 0.0010342150926589966 0.0007188505260273814 0.015292047522962093 0.03680204972624779
+0.03391272574663162 0.0003210797149222344 0.00028309033950790763 0.013018150813877583 0.030108120292425156
+0.03306152671575546 0.0001767922512954101 0.0005502952262759209 0.0168856680393219 0.029533052816987038
+0.02443435788154602 0.0002752594882622361 0.00042413576738908887 0.010960279032588005 0.007158090826123953
+0.030267339199781418 5.473990677273832e-05 0.0001395582512486726 0.008990949019789696 0.02538219280540943
+0.02886110730469227 0.0001039063063217327 0.00017746824596542865 0.012959696352481842 0.0263326708227396
+0.03560805320739746 0.00010270770872011781 0.0002389844012213871 0.017285281792283058 0.016128232702612877
+0.04075642302632332 6.247386772884056e-05 0.00027919537387788296 0.03422187641263008 0.03210047259926796
+0.045437805354595184 0.000169519436894916 0.00022002894547767937 0.01528649590909481 0.06182963773608208
+0.046938128769397736 0.00026250985683873296 0.0008405173430219293 0.022777201607823372 0.05362510308623314
+0.03289669007062912 9.977367153624073e-05 0.0005018420051783323 0.02330060489475727 0.03250521793961525
+0.032721422612667084 0.000611679395660758 0.0013885481748729944 0.01669560745358467 0.03923768550157547
+0.05090995877981186 6.785812001908198e-05 0.0003218288184143603 0.010899723507463932 0.06370532512664795
+0.04081428050994873 0.0014576204121112823 0.0011999206617474556 0.01523244846612215 0.021832991391420364
+0.0442764014005661 0.0005335331079550087 0.000976756913587451 0.05166018754243851 0.025737689808011055
+0.03730790317058563 0.0006736592622473836 0.0004413246351759881 0.009746253490447998 0.02862033061683178
+0.03473901003599167 0.001298949122428894 0.0006135907606221735 0.011637839488685131 0.04253177344799042
+0.03514464572072029 7.259123958647251e-05 0.0003985721559729427 0.023987742140889168 0.04147128388285637
+0.03138212114572525 0.00016057775064837188 0.00037484592758119106 0.01836296170949936 0.022099142894148827
+0.03878891095519066 7.86724267527461e-05 0.00028655974892899394 0.009669101797044277 0.03243567422032356
+0.028393540531396866 0.000166905956575647 0.00022590262233279645 0.013489443808794022 0.019096707925200462
+0.03337357938289642 0.00040975274168886244 0.000372526264982298 0.018420139327645302 0.016362374648451805
+0.03773147985339165 5.4144664318300784e-05 0.00024035820388235152 0.01822253130376339 0.04019102454185486
+0.029025735333561897 0.0005577760748565197 0.00046631376608274877 0.01354481652379036 0.032234981656074524
+0.030112851411104202 0.00020846594998147339 0.00015554649871774018 0.01246414054185152 0.024886604398489
+0.04007156193256378 0.00023540113761555403 0.0007613933994434774 0.015891503542661667 0.039679065346717834
+0.03502816706895828 0.00019262485147919506 0.0007197419181466103 0.018872784450650215 0.023580916225910187
+0.03372378274798393 0.0002949015761259943 0.00037480168975889683 0.014883959665894508 0.016884008422493935
+0.03714834153652191 0.0005529238260351121 0.0007817469304427505 0.013955622911453247 0.03868631273508072
+0.02641964703798294 0.00017610052600502968 0.0002393670001765713 0.01843569055199623 0.02400720864534378
+0.029329819604754448 0.0008491434273310006 0.0005818660138174891 0.02029479667544365 0.033056631684303284
+0.03797019273042679 7.49933096813038e-05 0.00021866116730961949 0.014997255988419056 0.04137871786952019
+0.03115822561085224 0.00033723824890330434 0.0011142484145238996 0.01368100568652153 0.02052381820976734
+0.039034996181726456 6.156102608656511e-05 0.00029187893960624933 0.013114013709127903 0.024786455556750298
+0.028943806886672974 0.0005315426969900727 0.0005384717951528728 0.010514694266021252 0.0240210872143507
+0.03566557914018631 0.0003764028078876436 0.0010048550320789218 0.013546200469136238 0.038139671087265015
+0.03253566473722458 0.00025137406191788614 0.0002720157499425113 0.015809891745448112 0.04342669993638992
+0.038378145545721054 0.001368192839436233 0.0008210102678276598 0.013832890428602695 0.0580347515642643
+0.03375621140003204 0.0004950809525325894 0.0003907519858330488 0.014439590275287628 0.02750151976943016
+0.038808465003967285 0.0003003292076755315 0.000394919712562114 0.01536179892718792 0.03467801213264465
+0.04476214200258255 0.0010690097697079182 0.000573761819396168 0.030572114512324333 0.042402442544698715
+0.04572620987892151 0.00026339091709814966 0.00025158506468869746 0.010246004909276962 0.048176780343055725
+0.036175407469272614 0.0001082708258763887 0.00043176766484975815 0.05248017609119415 0.036100782454013824
+0.03740701824426651 0.0002636619610711932 0.0005137251573614776 0.024038998410105705 0.035980481654405594
+0.036008916795253754 0.00027774233603850007 0.0005369248101487756 0.01770504005253315 0.0385897122323513
+0.028792859986424446 6.33963500149548e-05 0.00016268786566797644 0.010568168014287949 0.03042694739997387
+0.03906375542283058 0.0003087279328610748 0.0003848611086141318 0.010693378746509552 0.04355468973517418
+0.042537376284599304 0.00029080177773721516 0.0004123589023947716 0.012699678540229797 0.03769819438457489
+0.030765019357204437 0.00013016347656957805 0.00019277945102658123 0.008648844435811043 0.008120991289615631
+0.034305501729249954 0.00038597514503635466 0.00045760461944155395 0.007930146530270576 0.03415559232234955
+0.03281218558549881 0.00013631524052470922 0.00026126281591132283 0.01126308273524046 0.03456997498869896
+0.03908177837729454 0.00036385125713422894 0.001114614075049758 0.023536549881100655 0.04481818899512291
+0.04264906048774719 0.0002101900172419846 0.0003936886496376246 0.019882112741470337 0.053005047142505646
+0.039454136043787 0.0011123541044071317 0.0019571331795305014 0.01579490303993225 0.05635331943631172
+0.03943777456879616 4.592911864165217e-05 0.0007048638071864843 0.012285109609365463 0.042580973356962204
+0.04394424706697464 0.001589420484378934 0.00045247102389112115 0.015128512866795063 0.043234266340732574
+0.039175089448690414 0.0013147786958143115 0.00038995782961137593 0.012473711743950844 0.022538963705301285
+0.039372023195028305 5.173024328541942e-05 0.00038171320920810103 0.010356313548982143 0.04044201597571373
+0.03853737562894821 0.0004062448279000819 0.0005242855404503644 0.017984237521886826 0.05289380997419357
+0.03702085092663765 0.00018087019270751625 0.0006288738804869354 0.024845065549016 0.04383993148803711
+0.042541008442640305 0.0002175239787902683 0.0005843768012709916 0.020793838426470757 0.0622972697019577
+0.04441457986831665 0.0003465299669187516 0.00019801217422354966 0.011948410421609879 0.07174982875585556
+0.03573702648282051 5.319157207850367e-05 0.00021433521760627627 0.022059224545955658 0.013835998252034187
+0.04569663107395172 0.0003180584462825209 0.00035052013117820024 0.010568534955382347 0.018400168046355247
+0.041153259575366974 0.00030908852932043374 0.00027849574689753354 0.009179653599858284 0.04586884006857872
+0.04049959406256676 4.864723450737074e-05 0.00017194211250171065 0.013633853755891323 0.04119275510311127
+0.03633187711238861 0.00026594067458063364 0.00036708603147417307 0.015124253928661346 0.043995290994644165
+0.027866948395967484 0.00024939427385106683 0.0004326039634179324 0.023707512766122818 0.02371128834784031
+0.033094026148319244 7.487538823625073e-05 0.0005205136258155107 0.04246921092271805 0.02664719894528389
+0.02818518504500389 0.0001774055854184553 0.0002991070505231619 0.011809831485152245 0.02240629494190216
+0.03732284903526306 0.0009392306674271822 0.00022853886184748262 0.01628265157341957 0.02999887615442276
+0.028266360983252525 0.0005331992288120091 0.00044868350960314274 0.01732291840016842 0.026983313262462616
+0.04182804003357887 0.00020033936016261578 0.00019950450223404914 0.020458141341805458 0.031802937388420105
+0.036756157875061035 0.0010900190100073814 0.0005154038080945611 0.013765737414360046 0.030629677698016167
+0.040533408522605896 0.00014622858725488186 0.0002482148411218077 0.011511540040373802 0.04452012851834297
+0.027761511504650116 0.0004940846702083945 0.001972647849470377 0.021953444927930832 0.022314690053462982
+0.023425528779625893 9.448030323255807e-05 0.0005429394077509642 0.01607700064778328 0.013468554243445396
+0.034729283303022385 0.000859990541357547 0.0007248842157423496 0.01499994471669197 0.04363943263888359
+0.03626325726509094 0.00040704102138988674 0.0007202116539701819 0.015890512615442276 0.028250915929675102
+0.030193110927939415 0.0005215418059378862 0.0001559975789859891 0.015313009731471539 0.02342967689037323
+0.037927936762571335 0.0010848575038835406 0.0003287728177383542 0.012218546122312546 0.04255755990743637
+0.027434466406702995 0.0002058411919279024 0.0002849706797860563 0.013181905262172222 0.022815225645899773
+0.039899908006191254 0.00021943566389381886 0.0009974908316507936 0.017615336924791336 0.03556162863969803
+0.041823334991931915 0.0001923425297718495 0.0006424891762435436 0.022648504003882408 0.03429149091243744
+0.02755815163254738 0.00011763576185330749 0.00017863766697701067 0.013186597265303135 0.019210854545235634
+0.0327189601957798 0.0002261564222862944 0.0005639109876938164 0.010063402354717255 0.04291770979762077
+0.029483351856470108 9.0399946202524e-05 0.0005661664181388915 0.018414806574583054 0.022538693621754646
+0.03404529392719269 0.0005618334398604929 0.00016880103794392198 0.01673954911530018 0.030410870909690857
+0.04008037596940994 0.0005654824199154973 0.00047493004240095615 0.01498816441744566 0.04078185185790062
+0.03122766874730587 0.00013246967864688486 0.0002426019054837525 0.013780427165329456 0.02333603985607624
+0.030145462602376938 0.0006360453553497791 0.0005937269888818264 0.014070761390030384 0.03516903892159462
+0.02824767678976059 0.0003424994938541204 0.0004191694315522909 0.015994146466255188 0.014700128696858883
+0.040501996874809265 0.0001841811026679352 0.00033052233629859984 0.014960233122110367 0.03921964019536972
+0.02797558717429638 0.000532320817001164 0.0006735663046129048 0.011300169862806797 0.030894774943590164
+0.031568676233291626 7.292494410648942e-05 0.0002717355964705348 0.014268140308558941 0.03768463805317879
+0.03498414531350136 0.00031103339279070497 0.0006168578402139246 0.016557883471250534 0.03841257467865944
+0.0352478064596653 8.356290345545858e-05 0.0002461975673213601 0.016494842246174812 0.032270465046167374
+0.032256580889225006 0.0002215855201939121 0.0005276441224850714 0.011456851847469807 0.006248445715755224
+0.04011530429124832 0.00011118080874439329 0.0004375614516902715 0.01411249116063118 0.04287496209144592
+0.040771059691905975 0.00018667185213416815 0.00021584400383289903 0.01760775037109852 0.055256299674510956
+0.026669662445783615 0.00023377910838462412 0.00021863890287932009 0.018168965354561806 0.02070753648877144
+0.035001374781131744 6.201731594046578e-05 0.0004714815877377987 0.013211505487561226 0.030491402372717857
+0.029274869710206985 5.197382051846944e-05 0.0002992673253174871 0.012402014806866646 0.030537176877260208
+0.037217289209365845 0.00023756892187520862 0.0005850399029441178 0.016394156962633133 0.023937014862895012
+0.02484963648021221 7.390919199679047e-05 0.0005114340456202626 0.0169629268348217 0.017439866438508034
+0.04389813169836998 0.00042675199802033603 0.0002978056436404586 0.011531691066920757 0.06182638555765152
+0.039121199399232864 0.00014509379980154335 0.00018403676222078502 0.009368830360472202 0.022312019020318985
+0.03429181128740311 0.000257835112279281 0.0012107818620279431 0.013385992497205734 0.009607084095478058
+0.03471517935395241 8.367583359358832e-05 0.0008719435427337885 0.011723566800355911 0.044366270303726196
+0.0393073745071888 0.00030312128365039825 0.00032229808857664466 0.015030906535685062 0.05444633215665817
+0.03737645596265793 0.0005751803983002901 0.0005706593510694802 0.025214431807398796 0.05229129269719124
+0.02936621382832527 0.00019621444516815245 0.0007070348947308958 0.023207247257232666 0.03565877676010132
+0.04246562346816063 0.0002082394785247743 0.00045564130414277315 0.01088868547230959 0.04002659395337105
+0.04252948239445686 0.0015086567727848887 0.00017176734399981797 0.011440971866250038 0.04300829395651817
+0.0464133694767952 0.001140122301876545 0.0003746196161955595 0.016106821596622467 0.01164829172194004
+0.03881162405014038 0.00018753900076262653 0.00034737010719254613 0.01638876646757126 0.025882404297590256
+0.033751025795936584 0.001061520422808826 0.000582184293307364 0.012633743695914745 0.03133057802915573
+0.0510789155960083 0.00023489732120651752 0.00033284840174019337 0.015479162335395813 0.0853257104754448
+0.03174908459186554 0.0001679472188698128 0.0016787882195785642 0.024943815544247627 0.02965962328016758
+0.03254445269703865 0.00011316467862343416 0.00035860357456840575 0.019817965105175972 0.03468126431107521
+0.034906014800071716 0.0007170559256337583 0.000984423328191042 0.011363242752850056 0.02825668454170227
+0.03308062255382538 7.538850331911817e-05 0.0004575850907713175 0.011940089985728264 0.02450430952012539
+0.03929593414068222 0.0014866741839796305 0.0005193875404074788 0.011520116589963436 0.03573719412088394
+0.03430527076125145 0.0009072513785213232 0.0003185459354426712 0.015698766335844994 0.011771176941692829
+0.04378434270620346 0.00025433607515878975 0.0006116066942922771 0.009550671093165874 0.04272255301475525
+0.035583361983299255 0.0005446872673928738 0.0005453810445033014 0.01498360000550747 0.0369340181350708
+0.03374424949288368 0.00011001013626810163 0.0006225077086128294 0.021576402708888054 0.041890956461429596
+0.04013738036155701 0.0001348185760434717 0.0008713387651368976 0.016649767756462097 0.05968292057514191
+0.03263844549655914 0.0001537450443720445 0.00024360725365113467 0.014243788085877895 0.02264641970396042
+0.042498886585235596 0.0002370551083004102 0.00037009455263614655 0.01667410507798195 0.04642735421657562
+0.034364957362413406 0.0005575706600211561 0.0002893944620154798 0.014480243436992168 0.036473434418439865
+0.0320206955075264 0.0003192214644514024 0.00019236440130043775 0.010749476961791515 0.012286609038710594
+0.03606092929840088 0.0002594421384856105 0.00010454878065502271 0.014476924203336239 0.03763505443930626
+0.028738053515553474 9.323844278696924e-05 0.00024421326816082 0.012967103160917759 0.01845586486160755
+0.04910194128751755 0.00014055179781280458 0.0003902566386386752 0.01264396496117115 0.07630868256092072
+0.032113730907440186 0.0001323241158388555 0.00025256825028918684 0.01792876049876213 0.03652152046561241
+0.044896241277456284 0.00012765971769113094 0.00026083816192112863 0.015040627680718899 0.05342036485671997
+0.051833391189575195 9.113138366956264e-05 0.00019752990920096636 0.012407070025801659 0.058902423828840256
+0.03822848200798035 0.00011172608355991542 0.0002589961513876915 0.015416623093187809 0.03037935309112072
+0.030553389340639114 8.43425586936064e-05 0.00034756746026687324 0.012524271383881569 0.020684117451310158
+0.03173365816473961 9.250913717551157e-05 0.0004581179819069803 0.01613352634012699 0.025629475712776184
+0.04209113121032715 0.00020089138706680387 0.0002688353997655213 0.019810907542705536 0.049190182238817215
+0.03610290586948395 0.00025450490647926927 0.0003876874106936157 0.020409485325217247 0.040352724492549896
+0.037691880017519 8.798929047770798e-05 0.00021430986816994846 0.01573207788169384 0.04731810465455055
+0.028731485828757286 0.0001923277013702318 0.00020535661315079778 0.018825596198439598 0.03497536852955818
+0.03868190571665764 0.00022971471480559558 0.00035552907502278686 0.017813706770539284 0.04446602240204811
+0.0312889888882637 4.886702663497999e-05 0.0005839270888827741 0.010997189208865166 0.02769920974969864
+0.03767934441566467 4.590471144183539e-05 0.0005713627906516194 0.010179493576288223 0.021016083657741547
+0.03273387998342514 0.0002833499456755817 0.00014060811372473836 0.011783103458583355 0.019839908927679062
+0.03729860484600067 0.0003405441821087152 0.0005694285500794649 0.013540063053369522 0.018510449677705765
+0.030397657305002213 0.00021010312775615603 0.00027701500221155584 0.013527830131351948 0.022499632090330124
+0.046666838228702545 0.0010005934163928032 0.00038711557863280177 0.014229926280677319 0.071930892765522
+0.042557988315820694 0.0002947271568700671 0.00035628266050480306 0.024016182869672775 0.07426206767559052
+0.035280898213386536 0.00013234188372734934 0.0005585727631114423 0.029892779886722565 0.03917535021901131
+0.03736293688416481 0.0001315614063059911 0.00048402330139651895 0.01615901105105877 0.037868011742830276
+0.03365296125411987 0.00030303309904411435 0.00022830757370684296 0.017944660037755966 0.02212398685514927
+0.04074617102742195 0.0007333775865845382 0.0005263709463179111 0.09366798400878906 0.021111182868480682
+0.03299102932214737 0.00017713391571305692 0.0011231342796236277 0.02347327023744583 0.028206195682287216
+0.03543225675821304 0.000803985632956028 0.00037094217259436846 0.016501322388648987 0.04023255780339241
+0.03499964624643326 0.001190092647448182 0.0003545148647390306 0.014213238842785358 0.03719022870063782
+0.03175725042819977 0.000140954100061208 0.0010762176243588328 0.023538023233413696 0.03015049174427986
+0.03442218899726868 0.00010256210953230038 0.0006968511152081192 0.021878253668546677 0.03127100318670273
+0.029055863618850708 8.091278141364455e-05 0.0002769393031485379 0.012228166684508324 0.020915936678647995
+0.024750206619501114 6.255506741581485e-05 0.0009374052751809359 0.01322139147669077 0.015229858458042145
+0.024620912969112396 0.0002299421321367845 0.0005400344380177557 0.01498000044375658 0.012852739542722702
+0.025364950299263 0.0005782394437119365 0.0003430162323638797 0.0118180513381958 0.01931583695113659
+0.03131406754255295 0.00031962976208887994 0.0003938851004932076 0.010449567809700966 0.019372709095478058
+0.0285109244287014 5.9466747188707814e-05 0.0003711754397954792 0.011590009555220604 0.023838715627789497
+0.032267387956380844 7.775691483402625e-05 0.00018725280824583024 0.01245973166078329 0.02293301746249199
+0.04760751128196716 7.99045228632167e-05 0.00021589112293440849 0.01603124290704727 0.032415006309747696
+0.036346666514873505 0.0001225342566613108 0.00034804706228896976 0.013294999487698078 0.026901444420218468
+0.029472481459379196 0.00013619432866107672 0.0004981712554581463 0.013031219132244587 0.014997179619967937
+0.03421933948993683 0.00011064161662943661 0.0005584681057371199 0.012150771915912628 0.0362415574491024
+0.03567430004477501 6.94112604833208e-05 0.0004116093332413584 0.019125154241919518 0.04110465943813324
+0.03318578377366066 7.018958422122523e-05 0.0002315563033334911 0.017360281199216843 0.030698012560606003
+0.03860560059547424 5.491043339134194e-05 0.0003155143931508064 0.012407473288476467 0.034196723252534866
+0.040025196969509125 0.0001703200105112046 0.00016507026157341897 0.01135584618896246 0.032845914363861084
+0.04237547516822815 0.00040256400825455785 0.00033144999179057777 0.016682038083672523 0.049424588680267334
+0.0380813330411911 6.01886713411659e-05 0.0002504259755369276 0.016537213698029518 0.041170068085193634
+0.02935369685292244 0.0003659506037365645 0.0006910133524797857 0.011815125122666359 0.019749069586396217
+0.028278056532144547 0.00024002073041629046 0.0005421345704235137 0.010273268446326256 0.02093346230685711
+0.0309701319783926 0.00025952019495889544 0.00027999075246043503 0.017211494967341423 0.012997443787753582
+0.03793107718229294 0.0005644875927828252 0.000444704492110759 0.012794632464647293 0.031443070620298386
+0.03590318560600281 4.6558259782614186e-05 0.00025813892716541886 0.014118829742074013 0.03209371492266655
+0.030455512925982475 0.00046669840230606496 0.0006793909124098718 0.012441190890967846 0.02802976220846176
+0.023546431213617325 0.00019266718300059438 0.0005073233041912317 0.01614871621131897 0.006089367903769016
+0.035516560077667236 0.00016377614520024508 0.00022200285457074642 0.012749632820487022 0.04339905455708504
+0.041653022170066833 0.000587280432227999 0.0007374075939878821 0.015334094874560833 0.06751907616853714
+0.028237540274858475 8.30297140055336e-05 0.00032469656434841454 0.021065954118967056 0.02153368853032589
+0.03928929567337036 0.0008316478342749178 0.0012759251985698938 0.015517149120569229 0.04779946058988571
+0.03719210997223854 0.0002121909346897155 0.0006565466756001115 0.01635824330151081 0.0488787516951561
+0.037213727831840515 0.0008231429383158684 0.0005676824366673827 0.015616833232343197 0.03191651031374931
+0.03281945735216141 0.0009683892712928355 0.0009104009950533509 0.016996145248413086 0.023123156279325485
+0.0334542915225029 0.0001218182806042023 0.00020905130077153444 0.012194672599434853 0.03445175290107727
+0.037207767367362976 0.0004268433840479702 0.00022686042939312756 0.012997773475944996 0.018686147406697273
+0.03059670701622963 0.00020202381710987538 0.0011033196933567524 0.01543297991156578 0.005826557520776987
+0.029209373518824577 0.00020015262998640537 0.001105114701204002 0.013573269359767437 0.02914264425635338
+0.033578529953956604 0.0003516669094096869 0.000543687492609024 0.018905606120824814 0.03788746893405914
+0.040408045053482056 0.00031583019881509244 0.00030147991492412984 0.014595195651054382 0.06061582639813423
+0.02900591306388378 4.998368240194395e-05 0.00043510549585334957 0.019528819248080254 0.030953340232372284
+0.035323552787303925 0.0003012540691997856 0.0007492221775464714 0.013257082551717758 0.03513848036527634
+0.04726377874612808 0.0004896901664324105 0.0010107874404639006 0.02722792886197567 0.02216021902859211
+0.041865091770887375 0.0008846365381032228 0.00040431623347103596 0.00864845048636198 0.018887920305132866
+0.03005467914044857 0.0007371759274974465 0.0002229168894700706 0.017745619639754295 0.011673188768327236
+0.028932631015777588 5.145469185663387e-05 0.00034282036358490586 0.014432728290557861 0.018304897472262383
+0.02978048101067543 0.0010862834751605988 0.00038496917113661766 0.02464703470468521 0.035100895911455154
+0.028896404430270195 0.001254187198355794 0.0006092985277064145 0.025618940591812134 0.03201825171709061
+0.041329991072416306 0.00015907007036730647 0.0004432958085089922 0.02073969505727291 0.06615009903907776
+0.02483006939291954 0.000255166320130229 0.00035737614962272346 0.0172243881970644 0.0174317117780447
+0.03896009922027588 0.0003011659428011626 0.00021203978394623846 0.01396955642849207 0.008067944087088108
+0.03214649856090546 0.00015188752149697393 0.0001750973897287622 0.012177874334156513 0.01697792112827301
+0.025718320161104202 0.0002550046774558723 0.00016590602172072977 0.013810244388878345 0.012086419388651848
+0.03558269888162613 0.00010242111602565274 0.0004417126183398068 0.01312161423265934 0.031212782487273216
+0.03385280817747116 0.0002247139491373673 0.0003950998361688107 0.01566315069794655 0.03208107501268387
+0.03043385222554207 0.0007451553246937692 0.00039651579572819173 0.01223214715719223 0.027035344392061234
+0.034579548984766006 0.0003312558983452618 0.000427925813710317 0.018419265747070312 0.04676719009876251
+0.035894230008125305 5.5023607274051756e-05 0.00035031713196076453 0.021999036893248558 0.03183135390281677
+0.035964518785476685 4.033435834571719e-05 0.0002843734691850841 0.011437805369496346 0.0374312698841095
+0.03884873539209366 0.00024894767557270825 0.000298238213872537 0.011395097710192204 0.017693055793642998
+0.03477595001459122 0.0009183208458125591 0.00013558704813476652 0.013238264247775078 0.01334806066006422
+0.02980397269129753 0.00036787783028557897 0.0001680707064224407 0.015011665411293507 0.01882312260568142
+0.02986951172351837 0.00020273178233765066 0.00021782671683467925 0.025807790458202362 0.03143018111586571
+0.038580670952796936 0.0008012373582459986 0.0002313862059963867 0.015349417924880981 0.05116691440343857
+0.034797459840774536 0.00018358307715971023 0.00018859993724618107 0.01621078886091709 0.01571907289326191
+0.04242534935474396 5.308518302626908e-05 0.00027961304294876754 0.01132874097675085 0.05888519063591957
+0.027570417150855064 0.00012058050924679264 0.0001793240080587566 0.01476812269538641 0.021205849945545197
+0.028935326263308525 4.35494766861666e-05 0.0002822466194629669 0.016931964084506035 0.029357057064771652
+0.036063432693481445 0.00010649851174093783 0.0006833302322775126 0.015874095261096954 0.030248768627643585
+0.02374839410185814 4.201315823593177e-05 0.0008257495355792344 0.014638722874224186 0.00744421174749732
+0.03414979577064514 0.00035393270081840456 0.0003175562305841595 0.011672244407236576 0.03867174685001373
+0.030170531943440437 0.0009070899104699492 0.000311807612888515 0.01170616876333952 0.028399452567100525
+0.022243157029151917 0.00023202181910164654 0.00038152391789481044 0.014496438205242157 0.014078846201300621
+0.03204068914055824 0.0006113729323260486 0.0002066386805381626 0.013938704505562782 0.03943292796611786
+0.03192433714866638 0.001197969657368958 0.0003303065604995936 0.011932425200939178 0.037224821746349335
+0.029593780636787415 8.036258077481762e-05 0.0003641076909843832 0.010971269570291042 0.027277614921331406
+0.047234851866960526 0.00044906517723575234 0.0013879769248887897 0.011960756033658981 0.05346163362264633
+0.0357365757226944 0.00010927916446235031 0.00042545809992589056 0.018134530633687973 0.026204025372862816
+0.02923041582107544 0.0004457724862731993 0.0006363403517752886 0.01720820553600788 0.023854373022913933
+0.029873166233301163 0.00023065407003741711 0.0006931682582944632 0.013433651067316532 0.014512571506202221
+0.030837055295705795 0.0005280128680169582 0.00034458961454220116 0.014047332108020782 0.028394481167197227
+0.03359062224626541 0.0008659434388391674 0.0007785719353705645 0.021310511976480484 0.03796636313199997
+0.03125864639878273 0.000121580611448735 0.0006405371823348105 0.012622103095054626 0.03410201892256737
+0.03966338187456131 0.0007121206726878881 0.001375088351778686 0.01114518754184246 0.05216110125184059
+0.029547180980443954 6.602391658816487e-05 0.0002126191247953102 0.01166311465203762 0.02617727220058441
+0.03231485188007355 0.0006848917109891772 0.0007828989182598889 0.01409850362688303 0.01865304447710514
+0.025508461520075798 9.623686491977423e-05 0.00030619348399341106 0.007873613387346268 0.011921384371817112
+0.03768192604184151 0.0003310743486508727 0.0003106570220552385 0.00794621929526329 0.045826658606529236
+0.028830818831920624 0.0001324891927652061 0.0002710103872232139 0.014511403627693653 0.01727655716240406
+0.03677963837981224 6.481575837824494e-05 0.000592023425269872 0.022779546678066254 0.047565117478370667
+0.025122012943029404 0.00012781740224454552 0.0003457434067968279 0.015310505405068398 0.018704665824770927
+0.035637419670820236 8.568223711336032e-05 0.0003865062608383596 0.012697779573500156 0.04321303591132164
+0.037530943751335144 0.00025187639403156936 0.00033183913910761476 0.01504728477448225 0.031140420585870743
+0.03512256592512131 0.00047571788309141994 0.00030511824297718704 0.024745037779211998 0.01790868304669857
+0.03996678441762924 6.484017649199814e-05 0.0004436741874087602 0.017157450318336487 0.026076508685946465
+0.028020191937685013 0.00028578154160641134 0.0007522411178797483 0.01187498401850462 0.03180939704179764
+0.039578743278980255 0.0001778277801349759 0.0003651466977316886 0.016070308163762093 0.05777230113744736
+0.027903344482183456 0.00038562313420698047 0.0011561526916921139 0.023535719141364098 0.020034056156873703
+0.03885355591773987 0.0004371737304609269 0.00031726242741569877 0.01564740762114525 0.03747259080410004
+0.05894707143306732 0.0009632842848077416 0.0012048858916386962 0.010818402282893658 0.08455768972635269
+0.035238537937402725 0.0004553311155177653 0.0007246130262501538 0.010152122005820274 0.02222885563969612
+0.04029694199562073 0.0013519622152671218 0.0021498266141861677 0.015401264652609825 0.04283149167895317
+0.030642833560705185 0.00016395222337450832 0.00015902220911812037 0.012656865641474724 0.012946576811373234
+0.028217384591698647 0.0018454095115885139 0.0017657640855759382 0.010746876709163189 0.025935614481568336
+0.03209308162331581 4.987888314644806e-05 0.00013118551578372717 0.014507584273815155 0.030186530202627182
+0.030391480773687363 0.0014584417222067714 0.0014378927880898118 0.02121540531516075 0.013080173172056675
+0.03252391517162323 0.00028316042153164744 0.00039034971268847585 0.025128928944468498 0.027778811752796173
+0.03516445308923721 0.0008037528023123741 0.0016559931682422757 0.015867257490754128 0.03637782856822014
+0.03389711678028107 0.00042247114470228553 0.0010671444470062852 0.02068965509533882 0.029522618278861046
+0.028025919571518898 0.0005491471383720636 0.00038023764500394464 0.019511401653289795 0.017485445365309715
+0.025403767824172974 0.0009537700680084527 0.0008611952071078122 0.017247814685106277 0.021599527448415756
+0.03386906534433365 7.531881419708952e-05 0.00022409192752093077 0.011755872517824173 0.025006704032421112
+0.03177504613995552 0.0006599580519832671 0.00045306471292860806 0.012582606635987759 0.03461109474301338
+0.03728058189153671 0.00015701772645115852 0.00019306072499603033 0.013984224759042263 0.04741009324789047
+0.0342683307826519 0.00010282567382091656 0.0005661554750986397 0.014429398812353611 0.028319723904132843
+0.02864580787718296 9.348437743028626e-05 0.0001688963093329221 0.009450419805943966 0.02017906680703163
+0.03228769823908806 8.950500341597944e-05 0.00019966809486504644 0.007137348875403404 0.03287466615438461
+0.02607995644211769 8.48571871756576e-05 0.00024687510449439287 0.011781672947108746 0.026128780096769333
+0.029420342296361923 0.00010348555952077731 0.00037176039768382907 0.020120518282055855 0.034503273665905
+0.024042952805757523 0.00021129834931343794 0.0007335373666137457 0.014554949477314949 0.011410746723413467
+0.03195264935493469 0.00017488595040049404 0.000991061911918223 0.009844902902841568 0.029846254736185074
+0.03655783832073212 0.0006392793729901314 0.00026046682614833117 0.011190234683454037 0.03828039392828941
+0.03378025069832802 0.0011745709925889969 0.0011299812467768788 0.018656214699149132 0.0365406796336174
+0.022283822298049927 0.00013440284237731248 0.00029030180303379893 0.014708339236676693 0.021994996815919876
+0.027848348021507263 0.0014607060002163053 0.001157138729467988 0.011170606128871441 0.01426300685852766
+0.020803142338991165 0.00016797921853139997 0.00011270450340816751 0.01078187208622694 0.014864001423120499
+0.03799285739660263 0.0007875166484154761 0.001896094297990203 0.017034944146871567 0.03750433400273323
+0.02953861653804779 9.555184078635648e-05 0.0004053988086525351 0.015914766117930412 0.018742622807621956
+0.024764643982052803 0.0008484846912324429 0.0007385035860352218 0.011547897942364216 0.0183302853256464
+0.026989886537194252 0.00027853267965838313 0.00042292510624974966 0.011562444269657135 0.016741132363677025
+0.034946732223033905 0.0005659135640598834 0.0006000104476697743 0.014098168350756168 0.024357028305530548
+0.03410457447171211 0.0006091960822232068 0.0005277139134705067 0.015377545729279518 0.021855594590306282
+0.02577674388885498 0.00012140146282035857 0.0005446119466796517 0.01280116941779852 0.024128511548042297
+0.02230348251760006 0.0005228911759331822 0.0007963338284753263 0.013073770329356194 0.021608682349324226
+0.034729402512311935 6.441340519813821e-05 0.0003480466839391738 0.013950401917099953 0.044690631330013275
+0.02960960380733013 0.00034074034192599356 0.00045675269211642444 0.012698572129011154 0.03316652029752731
+0.03769221156835556 0.00015463675663340837 0.000154808527440764 0.008157674223184586 0.048171184957027435
+0.04139602556824684 4.7209869080688804e-05 0.00016090991266537458 0.00900661014020443 0.031993187963962555
+0.038104601204395294 9.050618973560631e-05 0.0002087213797494769 0.012005957774817944 0.034211885184049606
+0.03252396732568741 6.323725392576307e-05 0.0002521067508496344 0.01608153246343136 0.030348625034093857
+0.04089336842298508 6.405132444342598e-05 0.00020123622380197048 0.011284379288554192 0.0699755921959877
+0.04327509552240372 7.094045577105135e-05 0.00026148484903387725 0.01563059724867344 0.07920292019844055
+0.025495510548353195 0.00010614881466608495 0.0002309534465894103 0.019161095842719078 0.02455126866698265
+0.03432980924844742 0.00015562227054033428 0.0006927901413291693 0.04428932070732117 0.016868967562913895
+0.027510065585374832 0.00022233360505197197 0.0001747250062180683 0.01096490491181612 0.006725530605763197
+0.031260743737220764 0.0011483868584036827 0.0006133111310191453 0.01092051062732935 0.01652812957763672
+0.027315698564052582 0.00024634768487885594 0.0005546670872718096 0.019188325852155685 0.033508967608213425
+0.02711532451212406 0.0010804791236296296 0.0009173299768008292 0.018412290140986443 0.027653520926833153
+0.034016963094472885 0.0016874552238732576 0.0017060984391719103 0.01689787022769451 0.033871300518512726
+0.03401925414800644 0.00012042854359606281 0.00016637305088806897 0.008859570138156414 0.040302421897649765
+0.03599298745393753 0.0022316037211567163 0.0016651300247758627 0.01041447278112173 0.02759961411356926
+0.030448688194155693 0.0003098447050433606 0.00022842560429126024 0.010000202804803848 0.022448904812335968
+0.026759590953588486 0.0006853418308310211 0.0007932235603220761 0.00933999102562666 0.02408210001885891
+0.03975606709718704 0.0003472273820079863 0.00015383843856398016 0.015076308511197567 0.03956243023276329
+0.03521853685379028 0.00011663963960018009 0.0009584941435605288 0.015876097604632378 0.0321480929851532
+0.04436488822102547 6.273808685364202e-05 0.0006671571754850447 0.012789204716682434 0.052046239376068115
+0.03245190531015396 0.0002554329694248736 0.00015877540863584727 0.014040789566934109 0.04714728519320488
+0.022305380553007126 0.00024778684019111097 0.00028227947768755257 0.015555436722934246 0.01862008310854435
+0.027854260057210922 6.0999966080999e-05 0.00013958566705696285 0.015405388548970222 0.02971232309937477
+0.029879191890358925 0.00022074552543926984 0.0001303995813941583 0.01383514516055584 0.014434625394642353
+0.028374526649713516 0.00022477685706689954 0.00020316036534495652 0.009480510838329792 0.0158078670501709
+0.03189344331622124 0.00018592140986584127 0.00017616551485843956 0.011602669954299927 0.023140843957662582
+0.03373635560274124 5.7257268053945154e-05 0.00020205325563438237 0.018688281998038292 0.03549707308411598
+0.032099589705467224 0.00031009645317681134 0.0002936196397058666 0.024149850010871887 0.04292900487780571
+0.02132621593773365 0.0006621935172006488 0.0006756652728654444 0.012945177033543587 0.011790980584919453
+0.027124004438519478 5.374878674047068e-05 0.00017690600361675024 0.012027079239487648 0.024577246978878975
+0.027699090540409088 0.000770806975197047 0.0004965173429809511 0.012023949064314365 0.03398549184203148
+0.025606928393244743 0.00026121680275537074 0.00016112372395582497 0.014572303742170334 0.019052919000387192
+0.023615358397364616 0.0003143803041893989 0.0006966623477637768 0.011076816357672215 0.016867395490407944
+0.02951684035360813 0.00014923013804946095 0.0002668431552592665 0.009030403569340706 0.011630117893218994
+0.04622192680835724 0.0006123404018580914 0.0013193093473091722 0.012129568494856358 0.024680327624082565
+0.020616523921489716 0.0002388624125160277 0.0011849664151668549 0.012257436290383339 0.007322084624320269
+0.04454294964671135 0.000764331838581711 0.00020018925715703517 0.012169770896434784 0.06325177848339081
+0.045058466494083405 0.001042512827552855 0.0004129005246795714 0.018384922295808792 0.07831989228725433
+0.028420323505997658 0.00011390903819119558 0.0003678407520055771 0.02768285945057869 0.02135290391743183
+0.031153159216046333 6.881851732032374e-05 0.00022787618217989802 0.015384014695882797 0.036260977387428284
+0.036207109689712524 0.00029575408552773297 0.00037850713124498725 0.009885350242257118 0.020131271332502365
+0.029821714386343956 0.0014285010984167457 0.0002051490155281499 0.011374704539775848 0.007971691899001598
+0.03270701318979263 0.0005166673799976707 0.0005263916100375354 0.015323600731790066 0.016961710527539253
+0.02422543242573738 0.0007605488644912839 0.000532096135430038 0.013633464463055134 0.005390237085521221
+0.03154970705509186 0.0014773613074794412 0.0007603284902870655 0.008331084623932838 0.030909975990653038
+0.036133378744125366 0.00010970813309540972 0.0007465843809768558 0.013554445467889309 0.0361943319439888
+0.03297923505306244 0.0008244556374847889 0.002376365475356579 0.0227944515645504 0.045537833124399185
+0.04726262763142586 0.00020164242596365511 0.00039255819865502417 0.015931464731693268 0.07686205208301544
+0.03765954449772835 0.0013717265101149678 0.0021619927138090134 0.014874446205794811 0.0531107559800148
+0.030632812529802322 9.419245907338336e-05 0.001015205169096589 0.04986086115241051 0.024341575801372528
+0.033933524042367935 0.0017816026229411364 0.00013833882985636592 0.007807672955095768 0.03373260051012039
+0.040971651673316956 0.0024164742790162563 0.0005377677734941244 0.008403263054788113 0.0348978117108345
+0.041977256536483765 0.0001379920431645587 0.00021775254572276026 0.009346511214971542 0.01231997087597847
+0.03131747245788574 0.0013327464694157243 0.000359641679096967 0.01580161601305008 0.01998993568122387
+0.038598187267780304 0.0016742381267249584 0.00044966989662498236 0.01581994630396366 0.05583509802818298
+0.02918156236410141 0.00013783756003249437 0.0010609475430101156 0.02949812076985836 0.039789583534002304
+0.035956867039203644 7.776570419082418e-05 0.0004873344150837511 0.020556068047881126 0.05098666623234749
+0.028351034969091415 4.833396815229207e-05 0.0002981197030749172 0.014861545525491238 0.026649758219718933
+0.03466900810599327 0.00021426184684969485 0.0003126913507003337 0.01436674315482378 0.015272143296897411
+0.025123950093984604 0.0008271408150903881 0.0002314666344318539 0.013851368799805641 0.003872748464345932
+0.03920464590191841 0.0001931923907250166 0.00015370378969237208 0.010692457668483257 0.023210518062114716
+0.04228110983967781 0.00015078835713211447 0.00018421154527459294 0.009951862506568432 0.058378007262945175
+0.022756291553378105 0.0002478614333085716 0.0001843128993641585 0.017021285369992256 0.02345489338040352
+0.025344401597976685 0.0002370610018260777 0.00019348578643985093 0.026926567777991295 0.012047864496707916
+0.028137583285570145 0.0002061855047941208 0.00016103220696095377 0.01056925393640995 0.02585987001657486
+0.03703835979104042 5.811062146676704e-05 0.00013382591714616865 0.01007851492613554 0.03816821798682213
+0.0246509350836277 0.0004502233932726085 0.00015762739349156618 0.012285474687814713 0.022327568382024765
+0.038672663271427155 0.0002136058610631153 0.0001771131792338565 0.013054708018898964 0.055403176695108414
+0.031268902122974396 6.058149301679805e-05 0.0005290049593895674 0.015428978018462658 0.0293497983366251
+0.022885262966156006 6.514288543257862e-05 0.00038430432323366404 0.010582080110907555 0.008548730984330177
+0.02775747887790203 0.00030673283617943525 0.0002832879254128784 0.009424746036529541 0.02161211334168911
+0.0324820950627327 0.00047037668991833925 0.0012885179603472352 0.008688378147780895 0.043670475482940674
+0.020444078370928764 0.0002540902642067522 0.0008698176825419068 0.01623784378170967 0.003926924429833889
+0.03221547231078148 0.0025245819706469774 0.0008014284539967775 0.010999330319464207 0.02673608809709549
+0.03835374116897583 0.000821699621155858 0.0005491378251463175 0.03490273281931877 0.038979947566986084
+0.02095053344964981 0.0010298630222678185 0.0013965179678052664 0.017518971115350723 0.013931064866483212
+0.04024515673518181 0.0009764097048901021 0.0007493095472455025 0.015885375440120697 0.03991919755935669
+0.03520781546831131 0.00016163584950845689 0.0014998629922047257 0.039316367357969284 0.016642918810248375
+0.027135131880640984 0.00019426763174124062 0.001578696072101593 0.009931705892086029 0.02259327657520771
+0.028734717518091202 0.0006887008785270154 0.000569655152503401 0.01311673503369093 0.02455185540020466
+0.027000460773706436 0.0005485614528879523 0.0007869470864534378 0.015145670622587204 0.027912860736250877
+0.028258806094527245 0.00038385725929401815 0.00041067576967179775 0.009934577159583569 0.036916643381118774
+0.03094835951924324 0.0007645689183846116 0.0003251198213547468 0.011501560918986797 0.029479719698429108
+0.03241850808262825 4.645804801839404e-05 0.0004538039502222091 0.01728370040655136 0.03298977389931679
+0.03204528987407684 9.879021672531962e-05 0.0003249702858738601 0.01157579105347395 0.032357051968574524
+0.031224675476551056 7.104202086338773e-05 0.00012156738375779241 0.009737884625792503 0.03176550939679146
+0.03528030216693878 8.821609662845731e-05 0.00014784229279030114 0.013588830828666687 0.020842140540480614
+0.026363424956798553 0.0001186029112432152 0.000263674883171916 0.014286596328020096 0.022239187732338905
+0.025072041898965836 0.00010249943443341181 0.00021460223069880158 0.017216671258211136 0.030180472880601883
+0.03288525715470314 0.0003593949950300157 0.0004735140537377447 0.011386060155928135 0.03960211202502251
+0.027407322078943253 5.554951712838374e-05 0.00015867587353568524 0.011821914464235306 0.02220233902335167
+0.031864531338214874 0.00038527543074451387 0.0002171803789678961 0.012571596540510654 0.02941807173192501
+0.02969990484416485 0.0002003626141231507 0.00013382310862652957 0.012218745425343513 0.03218298405408859
+0.029053019359707832 4.889792762696743e-05 0.00025957042817026377 0.013257120735943317 0.02678006887435913
+0.026201173663139343 0.0001409338292432949 0.00040241150418296456 0.008273891173303127 0.02577700838446617
+0.03196335211396217 5.011080065742135e-05 0.00025574423489160836 0.009840107522904873 0.02448561228811741
+0.027401212602853775 0.00012161902122898027 0.0003324821009300649 0.008904366753995419 0.02174832485616207
+0.03704259917140007 5.000819146516733e-05 0.0001915462635224685 0.011437452398240566 0.01797974854707718
+0.029015924781560898 0.0003040014998987317 0.0001675350358709693 0.010298460721969604 0.015936413779854774
+0.03467287868261337 0.00010418186866445467 0.00013953063171356916 0.00787337962538004 0.043773528188467026
+0.030126065015792847 0.00014522856508847326 0.0005331944557838142 0.014356483705341816 0.04172804579138756
+0.028433755040168762 4.392015762277879e-05 0.00033425725996494293 0.021138830110430717 0.031481776386499405
+0.037629976868629456 0.0005045905709266663 0.0008958496619015932 0.01379266194999218 0.040172480046749115
+0.02676633931696415 5.7427678257226944e-05 0.0005571605288423598 0.009966784156858921 0.022476674988865852
+0.03222258761525154 0.001171523705124855 0.000435372261563316 0.01149227935820818 0.03519574552774429
+0.0193520188331604 0.0007660400005988777 0.0005035598878748715 0.016126815229654312 0.005250199232250452
+0.024956900626420975 0.00039351455052383244 0.0004763175384141505 0.012317932210862637 0.01544052641838789
+0.03190405294299126 0.0009842606959864497 0.0008825981640256941 0.009753616526722908 0.035652391612529755
+0.034110430628061295 8.310053817695007e-05 0.00042112459777854383 0.0158067736774683 0.04129147529602051
+0.02818773314356804 0.0004476092290133238 0.0008848169818520546 0.01760862022638321 0.026180487126111984
+0.02040262706577778 0.00010501724318601191 0.0001943158422363922 0.008602031506597996 0.012132160365581512
+0.021085558459162712 0.00029971564072184265 0.0004832351696677506 0.007657080423086882 0.01519443653523922
+0.026105022057890892 0.00015043398889247328 0.00022421796165872365 0.013877120800316334 0.02007029578089714
+0.03325297683477402 0.0002995058021042496 0.0003124788054265082 0.01346577052026987 0.03303040936589241
+0.04296531528234482 8.13105289125815e-05 0.00017168876365758479 0.008852322585880756 0.06402483582496643
+0.03185552358627319 0.00014689892122987658 0.00022472150158137083 0.011028219945728779 0.03362981602549553
+0.030494939535856247 5.318508192431182e-05 0.0001941232185345143 0.01810402423143387 0.03970681130886078
+0.023928478360176086 5.168787538423203e-05 0.00015676955808885396 0.013218953274190426 0.026917284354567528
+0.028524765744805336 6.0782491345889866e-05 0.0002711217966862023 0.009946723468601704 0.033570051193237305
+0.03226444870233536 9.161895286524668e-05 9.0049572463613e-05 0.0117399413138628 0.028952432796359062
+0.02666451781988144 0.00029648150666616857 0.00043546719825826585 0.016625426709651947 0.020979613065719604
+0.03282727301120758 7.833025301806629e-05 0.000189448386663571 0.01433014776557684 0.04199487715959549
+0.027531743049621582 0.00040417598211206496 0.00021164774079807103 0.010294907726347446 0.02199576236307621
+0.022309135645627975 0.0001350821548840031 0.00016340061847586185 0.011108802631497383 0.004704990424215794
+0.024563144892454147 0.00019936467288061976 0.000672771071549505 0.012773006223142147 0.01391909271478653
+0.027767812833189964 9.619937918614596e-05 0.0006195182213559747 0.018066920340061188 0.02889767847955227
+0.03795246779918671 0.0004854958679061383 0.0005863031256012619 0.017692184075713158 0.04711776226758957
+0.031206144019961357 0.00045861670514568686 0.0010547790443524718 0.015017492696642876 0.037464968860149384
+0.025391949340701103 0.00024941156152635813 0.00033732541487552226 0.016190864145755768 0.012443475425243378
+0.03038303554058075 0.0010081662330776453 0.00020110918558202684 0.011428235098719597 0.02402256429195404
+0.01930164359509945 0.00025728807668201625 0.00021293811732903123 0.016525859013199806 0.0018598074093461037
+0.045368947088718414 4.3231404561083764e-05 0.00022886437363922596 0.010154119692742825 0.05730500444769859
+0.026546183973550797 4.430410990607925e-05 0.00034715610672719777 0.012099492363631725 0.024204861372709274
+0.02649880200624466 0.00018454750534147024 0.00034504447830840945 0.014882833696901798 0.024190101772546768
+0.021301191300153732 0.0008598455460742116 0.00011435429769335315 0.009237319231033325 0.012114167213439941
+0.030405618250370026 0.00022655284556094557 0.00010120793740497902 0.008440648205578327 0.02895490638911724
+0.028049953281879425 0.0004148669249843806 0.0008561156573705375 0.010632600635290146 0.018347132951021194
+0.02322084829211235 0.00028934775036759675 0.0004685772582888603 0.012561555951833725 0.012095591053366661
+0.029199350625276566 0.00019674829673022032 0.0007537170313298702 0.010985512286424637 0.03507046028971672
+0.027208559215068817 0.00012920233712065965 0.0008100222330540419 0.00932992808520794 0.026515251025557518
+0.03802557662129402 0.0002817523491103202 0.00020947103621438146 0.059948086738586426 0.028591258451342583
+0.03811468929052353 0.00043179592466913164 0.00029702988103963435 0.009590824134647846 0.0203899797052145
+0.036923136562108994 4.026344322483055e-05 0.00013236617087386549 0.011151045560836792 0.025560220703482628
+0.026643643155694008 0.00029518603696487844 0.00021341680258046836 0.013696486130356789 0.02534455806016922
+0.025995314121246338 0.00029698657453991473 0.0002844721602741629 0.019469600170850754 0.02355976775288582
+0.03793651610612869 0.0001563409750815481 0.0002602519525680691 0.015299558639526367 0.054157502949237823
+0.0267751757055521 4.772200190927833e-05 0.0003069648228120059 0.009542915038764477 0.022427501156926155
+0.026337217539548874 0.0003246109699830413 0.00015491781232412905 0.007748211733996868 0.02768082171678543
+0.03218014910817146 0.0006771694752387702 0.00016017410962376744 0.009343582205474377 0.03883388638496399
+0.03299987316131592 0.0001295185647904873 0.0002705516235437244 0.011821635998785496 0.02655021660029888
+0.03390049934387207 0.00015734169573988765 0.00027911190409213305 0.012731709517538548 0.021134287118911743
+0.0401613712310791 5.583686652244069e-05 0.0004392186237964779 0.010478684678673744 0.025958187878131866
+0.02818007580935955 0.00029249853105284274 0.003426174633204937 0.011846520006656647 0.0021879749838262796
+0.029461927711963654 0.00038241289439611137 0.0016013721469789743 0.01272741798311472 0.03838173672556877
+0.04118967428803444 0.003093599108979106 0.0010501574724912643 0.023948488757014275 0.04629134014248848
+0.033094678074121475 0.0018732844619080424 0.0018738344078883529 0.019364692270755768 0.043023381382226944
+0.02829884923994541 0.00037426588824018836 0.0008010461460798979 0.015101990662515163 0.024816768243908882
+0.03267695754766464 0.003529295092448592 0.0003276366333011538 0.013964728452265263 0.041589733213186264
+0.029866531491279602 0.0016513365553691983 0.0007684442680329084 0.016193732619285583 0.028743892908096313
+0.03128162771463394 6.59756624372676e-05 0.0030028775800019503 0.011565816588699818 0.03562186285853386
+0.03534872084856033 7.32430926291272e-05 0.0010297629050910473 0.010370326228439808 0.04579615220427513
+0.0394585058093071 0.00036138982977718115 0.0006610419368371367 0.010033673606812954 0.014501040801405907
+0.04124537110328674 0.00011127736797789112 0.001443652668967843 0.008312865160405636 0.0350780263543129
+0.028216343373060226 0.0006118547171354294 0.001007258309982717 0.008776489645242691 0.023746822029352188
+0.03235223889350891 0.0016638683155179024 0.0007509543211199343 0.020754499360919 0.023071810603141785
+0.026982299983501434 0.0013731892686337233 0.0007287113694474101 0.02274342067539692 0.02911531738936901
+0.0325920395553112 0.0002390034351265058 0.000419309624703601 0.056328825652599335 0.02582451142370701
+0.045319564640522 0.0007611095206812024 0.00031389802461490035 0.014309210702776909 0.06351057440042496
+0.03696737065911293 0.002029474824666977 0.0003625084937084466 0.018030205741524696 0.03589024767279625
+0.03498459234833717 0.00026050559245049953 0.0007779723382554948 0.013182658702135086 0.02776595950126648
+0.04247545823454857 0.00027551857056096196 0.002142396755516529 0.007524256594479084 0.03444695100188255
+0.034961726516485214 0.0001130139353335835 0.00011493994679767638 0.007147859316319227 0.012654050253331661
+0.04040026292204857 0.000721198448445648 0.0016049178084358573 0.014336680993437767 0.0402882881462574
+0.044892169535160065 0.0002613920660223812 0.00031556954490952194 0.0161574799567461 0.05321158096194267
+0.030117090791463852 0.002928395988419652 0.0013179805828258395 0.009554035030305386 0.04514637961983681
+0.03975153714418411 0.0005664521013386548 0.00043179281055927277 0.01048339530825615 0.05378352478146553
+0.024762682616710663 0.001384430448524654 0.0008682649349793792 0.011635493487119675 0.016351191326975822
+0.03704534098505974 0.0015679936623200774 0.00045674023567698896 0.012025860138237476 0.031308624893426895
+0.03238862007856369 5.642689939122647e-05 0.00034786557080224156 0.013671975582838058 0.010090285912156105
+0.03094073198735714 9.595745359547436e-05 0.00016214627248700708 0.012531091459095478 0.03813917562365532
+0.03439633548259735 8.35641985759139e-05 0.0008245277567766607 0.026584787294268608 0.04398134723305702
+0.038902752101421356 0.00048216906725429 0.0003962439077440649 0.01919347420334816 0.05757053196430206
+0.03677123039960861 0.0008929138421081007 0.0003643418604042381 0.009010929614305496 0.05441956967115402
+0.03382015973329544 8.264935604529455e-05 0.00018383321003057063 0.010690883733332157 0.03578845039010048
+0.03343404084444046 0.0007422277121804655 0.0006958140875212848 0.015231619589030743 0.019840314984321594
+0.020623790100216866 4.94185987918172e-05 0.00029492471367120743 0.010897515341639519 0.006901959422975779
+0.0324813537299633 0.0002074206859106198 0.0012584731448441744 0.008191483095288277 0.024373071268200874
+0.02945864573121071 0.00024252694856841117 0.00015570659888908267 0.012180174700915813 0.03616315498948097
+0.024402480572462082 0.0006513078114949167 0.0006174268783070147 0.02145846001803875 0.010602052323520184
+0.03274904936552048 0.00011740736226784065 0.00043985663796775043 0.020116761326789856 0.02426845207810402
+0.03363513574004173 0.0008785808458924294 0.0010846866061910987 0.013638079166412354 0.048755943775177
+0.02027418091893196 9.046645573107526e-05 0.00037236197385936975 0.015876907855272293 0.017102980986237526
+0.03530639410018921 0.00044894154416397214 0.0004696012183558196 0.015486916527152061 0.04651317000389099
+0.02846711128950119 0.00022319788695313036 0.00011593758972594514 0.010384473018348217 0.03244687244296074
+0.03361354023218155 4.41860465798527e-05 0.0002723008510656655 0.011514122597873211 0.027687976136803627
+0.03535497188568115 8.955311932368204e-05 0.0002821618691086769 0.007472474593669176 0.015519341453909874
+0.02854178473353386 0.00021888526680413634 0.001454589539207518 0.012124194763600826 0.02149869315326214
+0.029392080381512642 0.00021144942729733884 0.0007558821234852076 0.01274319738149643 0.031407199800014496
+0.02473439648747444 0.0013018220197409391 0.00016475594020448625 0.013773564249277115 0.01905600167810917
+0.02380647324025631 0.0006637226906605065 0.00027112202951684594 0.01342469546943903 0.01436891034245491
+0.031806580722332 6.150185072328895e-05 0.00043200835352763534 0.021727239713072777 0.03398243337869644
+0.03796708956360817 0.00019264167349319905 0.00022722322319168597 0.01946997083723545 0.05551822483539581
+0.02673763781785965 0.00020399531058501452 0.0009851506911218166 0.014151530340313911 0.027981435880064964
+0.02721267379820347 0.0002484898141119629 0.0012598801404237747 0.03483695536851883 0.012779629789292812
+0.025371912866830826 0.0005993361119180918 0.00043888724758289754 0.010959573090076447 0.02251742221415043
+0.03344542905688286 0.0007871855050325394 0.00015479227295145392 0.014062720350921154 0.013161945156753063
+0.02548164688050747 7.418281893478706e-05 0.00033364101545885205 0.008818376809358597 0.01854347623884678
+0.03304423391819 0.0008751288405619562 0.00016967521514743567 0.011117422953248024 0.03438117355108261
+0.03266111761331558 0.0013969041174277663 0.00025519830523990095 0.017024165019392967 0.04439703747630119
+0.03476167097687721 0.0005020094686187804 0.00036559641011990607 0.01894247718155384 0.054079070687294006
+0.03433160111308098 8.511180203640833e-05 0.00023733744455967098 0.014382691122591496 0.05052122473716736
+0.03237612545490265 0.0001668964687269181 0.00011338587501086295 0.010181150399148464 0.029160136356949806
+0.03437873348593712 0.0003609458508435637 0.0001779649028321728 0.011353005655109882 0.022845609113574028
+0.04357296973466873 0.00039923086296766996 0.00019662801059894264 0.0060869622975587845 0.01255084853619337
+0.02700859308242798 0.0003269191656727344 0.00020029596635140479 0.009714063256978989 0.0046103098429739475
+0.027352700009942055 0.00012311858881730586 0.0003978646418545395 0.01669280044734478 0.025412846356630325
+0.022708987817168236 0.0005839364603161812 0.0006863480084575713 0.02414826676249504 0.02052675373852253
+0.03223339468240738 0.0015187684912234545 0.0009291787282563746 0.017687732353806496 0.04208240658044815
+0.0325198769569397 0.0003296410432085395 0.0004473849548958242 0.01402535755187273 0.04871082305908203
+0.033655524253845215 0.000395463575841859 0.0002186880010413006 0.010089931078255177 0.0419585295021534
+0.031792812049388885 0.0007115596672520041 0.0001547778520034626 0.009205375798046589 0.01779652014374733
+0.03224306181073189 0.0004344213812146336 0.000245604693191126 0.008319806307554245 0.007917470298707485
+0.03548411279916763 0.00021694351744372398 0.00010522552474867553 0.00778436241671443 0.025637302547693253
+0.024757809937000275 0.00018337441724725068 0.0006279240478761494 0.0098866643384099 0.009976808913052082
+0.03027530387043953 0.0003091472899541259 0.0006933660479262471 0.016657685860991478 0.04444364458322525
+0.04392886906862259 0.0017219248693436384 0.00031358387786895037 0.015473373234272003 0.07804466038942337
+0.030059199780225754 0.0009327337611466646 0.00041459608473815024 0.025088271126151085 0.03144749999046326
+0.0367787629365921 0.00010010922414949164 0.00032026751432567835 0.013675831258296967 0.03609638661146164
+0.026069769635796547 0.0005391819286160171 0.000131034612422809 0.011105244979262352 0.012519512325525284
+0.03745437413454056 0.0009933793917298317 0.00010569175356067717 0.014113513752818108 0.006284252740442753
+0.029718194156885147 0.00048123832675628364 0.00011489504686323926 0.01032260898500681 0.007500171195715666
+0.02739384025335312 5.660573151544668e-05 0.00011707576777553186 0.009375380352139473 0.028046725317835808
+0.029143063351511955 0.0003312659391667694 0.00018716305203270167 0.025870412588119507 0.030235933139920235
+0.030978895723819733 0.000899157312233001 0.0003528802772052586 0.01989673636853695 0.04292534664273262
+0.030534029006958008 0.0003113538841716945 0.0003517586155794561 0.011659018695354462 0.03633064031600952
+0.025291891768574715 0.00042059345287270844 0.00023329445684794337 0.011006200686097145 0.025220835581421852
+0.02977193519473076 0.0007064478122629225 0.00023770553525537252 0.012425610795617104 0.023656342178583145
+0.03950700908899307 5.819835132570006e-05 0.0003677847853396088 0.01185321994125843 0.019081247970461845
+0.0356583297252655 0.00017989447223953903 0.0006435788236558437 0.01305090170353651 0.022844813764095306
+0.036177150905132294 0.0002686660154722631 0.0002652159018907696 0.01136934757232666 0.03144725784659386
+0.026535240933299065 0.0005183489993214607 0.0015105698257684708 0.01908057928085327 0.017754819244146347
+0.03153355419635773 0.0003747751470655203 0.00021616565936710685 0.01742967776954174 0.025852354243397713
+0.037479523569345474 0.0018470250070095062 0.0006487021455541253 0.010080056264996529 0.03836167976260185
+0.027496783062815666 0.00016861195035744458 0.00021913259115535766 0.019885199144482613 0.03325049206614494
+0.029747625812888145 0.0005436437204480171 0.0006577487220056355 0.02241024561226368 0.03428546339273453
+0.025994550436735153 0.0001152264594566077 0.0002844251866918057 0.01457950659096241 0.027378948405385017
+0.042303040623664856 9.404095908394083e-05 0.0011964536970481277 0.011086086742579937 0.04885898903012276
+0.03126160427927971 0.0006097666337154806 9.038789721671492e-05 0.008192530833184719 0.032687894999980927
+0.03336995095014572 0.0011537466198205948 0.0008677202858962119 0.014043770730495453 0.0034463482443243265
+0.025617239996790886 0.00013353486428968608 0.00015951460227370262 0.006711557041853666 0.009138461202383041
+0.027416866272687912 0.0012846743920817971 0.00021043888409622014 0.007813284173607826 0.02453633025288582
+0.03756486997008324 0.0003586797683965415 0.0007197769009508193 0.019047865644097328 0.05454452335834503
+0.026319770142436028 0.00015424138109665364 0.0008240310125984251 0.03255108371376991 0.025041749700903893
+0.02450704760849476 0.0005858954391442239 0.00077830720692873 0.010529763996601105 0.024003952741622925
+0.02703460119664669 8.12159851193428e-05 0.0011483628768473864 0.009125550277531147 0.022777562960982323
+0.029293444007635117 0.0020492018666118383 0.00018334352353122085 0.009395976550877094 0.017136717215180397
+0.030657432973384857 0.002212357008829713 0.0004017730534542352 0.009305854327976704 0.014155391603708267
+0.02757689543068409 7.002685742918402e-05 0.0005015020142309368 0.00996847078204155 0.009150074794888496
+0.022167131304740906 0.0010231512133032084 0.0008248921949416399 0.009485929273068905 0.01841791532933712
+0.021645454689860344 0.00011385257676010951 0.0005777858314104378 0.015254837460815907 0.016200225800275803
+0.030536407604813576 9.866469190455973e-05 0.001419863081537187 0.017116229981184006 0.0386492982506752
+0.02446819841861725 0.0003352236235514283 0.0001526601699879393 0.010127339512109756 0.02506220154464245
+0.03893096745014191 0.0007036080933175981 0.0006691860617138445 0.008584952913224697 0.059981122612953186
+0.0277804397046566 0.00013138899521436542 0.0001971337478607893 0.011511024087667465 0.03781113401055336
+0.03489760309457779 0.0009906931081786752 0.0002606325433589518 0.015700846910476685 0.03774626925587654
+0.020274750888347626 0.00017324155487585813 0.00039243543869815767 0.010177413001656532 0.010446176864206791
+0.03578684478998184 4.383882333058864e-05 0.0005250165704637766 0.015001106075942516 0.022113123908638954
+0.02681049332022667 0.0002540639543440193 0.0003877759736496955 0.010154500603675842 0.017938638105988503
+0.02719009667634964 0.00010015760926762596 0.0008596302941441536 0.01409904845058918 0.02022228203713894
+0.025180011987686157 0.0008122037979774177 0.00031076863524504006 0.014771444723010063 0.012082253582775593
+0.027732156217098236 0.0009973429841920733 0.0006411986541934311 0.011602814309298992 0.024442847818136215
+0.02606486901640892 0.00010542404925217852 0.0015812115743756294 0.02745850570499897 0.0304890014231205
+0.02496069297194481 9.73269998212345e-05 0.0006524216150864959 0.02138233371078968 0.02576279081404209
+0.03023599646985531 0.00017468344594817609 0.0015044010942801833 0.009411080740392208 0.042880065739154816
+0.020436078310012817 0.00022515899036079645 0.0008478967356495559 0.007825606502592564 0.0067255799658596516
+0.03628043830394745 0.0020993046928197145 0.00017043633852154016 0.008236636407673359 0.03018856607377529
+0.03702589124441147 0.001318416208960116 0.0003100436006207019 0.006715381983667612 0.012282946147024632
+0.039892103523015976 0.0001401600311510265 0.000146709120599553 0.008151563815772533 0.03815943002700806
+0.030132094398140907 0.0010780316079035401 0.00014601086149923503 0.01062597893178463 0.02904830127954483
+0.03846320882439613 0.00064874003874138 0.0006517937290482223 0.021564757451415062 0.05094646289944649
+0.03166189044713974 0.0003583889047149569 0.00054294423898682 0.025576304644346237 0.03413335233926773
+0.040611229836940765 0.00015984574565663934 0.0002401103120064363 0.014573466032743454 0.04299319162964821
+0.028384320437908173 0.00012620884808711708 0.00019495609740260988 0.01560209784656763 0.027283206582069397
+0.0319543331861496 0.0005094129010103643 0.0002893338387366384 0.02019539661705494 0.031695276498794556
+0.030327802523970604 0.0001503534149378538 0.000695485679898411 0.013407551683485508 0.03150486573576927
+0.03108028694987297 5.2613148000091314e-05 0.0010693501681089401 0.01101811695843935 0.01959017477929592
+0.02773822657763958 0.00031237691291607916 0.00012436964607331902 0.007072733249515295 0.025965958833694458
+0.03132132068276405 0.0005482378182932734 0.0005897556548006833 0.00807195995002985 0.012005174532532692
+0.030041003599762917 0.00014491527690552175 0.00018331345927435905 0.008913793601095676 0.01403117086738348
+0.029658403247594833 0.0011534062214195728 0.00014421525702346116 0.005149621516466141 0.032940078526735306
+0.020373284816741943 0.0002122727019013837 0.0005795548786409199 0.01189501490443945 0.01830906607210636
+0.037037983536720276 8.504926518071443e-05 0.0010939757339656353 0.01955609954893589 0.0512051098048687
+0.02726544626057148 0.00025725940940901637 0.0003449406358413398 0.012979929335415363 0.031429193913936615
+0.025461606681346893 0.00020806209067814052 0.0008344216039404273 0.012458919547498226 0.02957787550985813
+0.02095872536301613 0.0005303659127093852 0.00025950337294489145 0.014918796718120575 0.015374619513750076
+0.026118651032447815 0.0011493596248328686 0.00018922016897704452 0.011830933392047882 0.020459569990634918
+0.03430046886205673 0.0002102382859447971 0.00016955113096628338 0.012351363897323608 0.0036841221153736115
+0.04962697997689247 0.00014236582501325756 0.00027612358098849654 0.00892618577927351 0.042517006397247314
+0.03198399394750595 6.77607677062042e-05 0.0005415066261775792 0.016113117337226868 0.03365588188171387
+0.03587358444929123 0.00010815467248903587 0.0012134646531194448 0.03150582313537598 0.04349713772535324
+0.038035474717617035 0.0014823536621406674 0.0003184037341270596 0.024613652378320694 0.060299646109342575
+0.0235780980437994 0.0012427382171154022 0.0012085868511348963 0.013972839340567589 0.024434978142380714
+0.026841890066862106 0.00043854943942278624 0.00010143012332264334 0.009427185170352459 0.028844717890024185
+0.028480511158704758 0.0032171946950256824 0.0006305637652985752 0.011922581121325493 0.013271601870656013
+0.03774336352944374 0.0006170897395350039 0.00017967850726563483 0.010512125678360462 0.01717417500913143
+0.03988929092884064 0.0002974133822135627 0.0012261407682672143 0.012136152945458889 0.01862805150449276
+0.03231888636946678 0.00012795247312169522 0.00012640775821637362 0.0077803474850952625 0.019735028967261314
+0.022823091596364975 0.0003777960955630988 0.0018797247903421521 0.011735501699149609 0.012514731846749783
+0.022873839363455772 0.00014723198546562344 0.0011230326490476727 0.015240584500133991 0.020318059250712395
+0.02985834889113903 0.001936319749802351 0.00046354110236279666 0.016330063343048096 0.03515126183629036
+0.037043940275907516 0.0014206591295078397 0.0009241990628652275 0.014551661908626556 0.05355053022503853
+0.038410525768995285 0.00021360143728088588 0.00031466694781556726 0.012623812071979046 0.05415090173482895
+0.027140090242028236 0.0015993185807019472 0.0004052880685776472 0.012067591771483421 0.013677822425961494
+0.023495791479945183 0.0004490251885727048 0.0004944807733409107 0.01093247439712286 0.0035861025098711252
+0.026231199502944946 7.769342482788488e-05 0.0007924191886559129 0.00974484533071518 0.007007373031228781
+0.026488181203603745 0.00036042288411408663 0.00024208668037317693 0.017821935936808586 0.015864916145801544
+0.023794950917363167 0.00010102906526299194 0.0005772889708168805 0.015708137303590775 0.009346814826130867
+0.020379740744829178 0.0007963653770275414 0.0003844529273919761 0.01190421637147665 0.007986786775290966
+0.03702854365110397 0.0010095384204760194 0.0002628501388244331 0.035122647881507874 0.03937723860144615
+0.04053984954953194 9.859801502898335e-05 0.0009085815399885178 0.01743277721107006 0.07117896527051926
+0.03882891684770584 0.00014552132051903754 0.001065908814780414 0.023428168147802353 0.038136061280965805
+0.016470644623041153 0.0002892091288231313 0.0010222982382401824 0.01207643561065197 0.0030397900845855474
+0.025341937318444252 7.174186612246558e-05 0.0013369352091103792 0.015020392835140228 0.019177071750164032
+0.026835937052965164 0.0017558276886120439 0.0002549665223341435 0.006697813980281353 0.024250976741313934
+0.029567545279860497 0.002063994063064456 0.0008480569231323898 0.010401973500847816 0.012313475832343102
+0.026843708008527756 0.0002640538732521236 0.0004211766936350614 0.00918613001704216 0.024663660675287247
+0.027512377128005028 0.0020998474210500717 0.0013057474279776216 0.006191635970026255 0.025108804926276207
+0.024686390534043312 8.278861787403002e-05 0.0004263849405106157 0.01279783807694912 0.02741493284702301
+0.03208309784531593 0.0006508352817036211 0.0023789897095412016 0.01834877021610737 0.02797716110944748
+0.025147266685962677 0.00014788954285904765 0.000158209542860277 0.01299036294221878 0.012009014375507832
+0.023563126102089882 0.0012230309657752514 0.0012721604434773326 0.00942972768098116 0.015567610040307045
+0.02344675362110138 5.3976691560819745e-05 0.00017711705004330724 0.011665108613669872 0.013608751818537712
+0.031276024878025055 0.0015752757899463177 0.0007534699980169535 0.014291149564087391 0.02164117805659771
+0.024791140109300613 0.0004005847149528563 0.000196684108232148 0.013893187046051025 0.021840162575244904
+0.0301185455173254 0.000367212196579203 0.0009081057505682111 0.010006733238697052 0.04110568016767502
+0.03593233972787857 0.00019703592988662422 0.0003535948053468019 0.015069886110723019 0.04299892485141754
+0.03578699752688408 0.00011614807590376586 0.0003761777770705521 0.014124415814876556 0.0433083102107048
+0.027464168146252632 9.527932706987485e-05 0.00039003038546070457 0.008490626700222492 0.022795792669057846
+0.03310547024011612 6.150501576485112e-05 0.0002594149555079639 0.008695228956639767 0.02187913842499256
+0.0181712843477726 7.951157749630511e-05 0.00020263699116185308 0.009484419599175453 0.006928388494998217
+0.02688605710864067 0.00014236278366297483 0.00022871489636600018 0.017673302441835403 0.029775310307741165
+0.03182992339134216 0.00018633640138432384 0.0002513011859264225 0.01219403836876154 0.035826511681079865
+0.023172665387392044 6.9028879806865e-05 0.00020946026779711246 0.01450380776077509 0.017466386780142784
+0.024921488016843796 0.0003461168089415878 0.00016968004638329148 0.013205948285758495 0.026000842452049255
+0.023401407524943352 0.00021626793022733182 0.000266915449174121 0.013757128268480301 0.008427069522440434
+0.025599494576454163 6.485676567535847e-05 0.0003425180329941213 0.009516154415905476 0.023590875789523125
+0.027882583439350128 7.270663627423346e-05 0.00023559550754725933 0.011713553220033646 0.02739952690899372
+0.02715734764933586 6.877132545923814e-05 0.00011564174928935245 0.012181008234620094 0.02260977402329445
+0.029858676716685295 5.8978293964173645e-05 9.095367568079382e-05 0.008872494101524353 0.03254779800772667
+0.0361345075070858 0.0001129178999690339 0.00014748016837984324 0.010456360876560211 0.05418413132429123
+0.02218751236796379 4.144967897445895e-05 0.00011051798355765641 0.009332656860351562 0.013509736396372318
+0.027097878977656364 7.174593338277191e-05 7.76034576119855e-05 0.008842536248266697 0.007697003893554211
+0.022217828780412674 5.3448049584403634e-05 0.000120206517749466 0.009313863702118397 0.013023270294070244
+0.026864442974328995 0.00014145734894555062 0.0002587384369689971 0.011327330954372883 0.03552001342177391
+0.02845773473381996 0.0001164411660283804 0.0006637938786298037 0.019526604562997818 0.03466988354921341
+0.025408634915947914 0.000447404250735417 0.0005260399193502963 0.02113182097673416 0.03304391726851463
+0.021276354789733887 0.0006529713282361627 0.0008317311876453459 0.011974988505244255 0.010248667560517788
+0.0314827561378479 0.00019898952450603247 0.00017803491209633648 0.007751293946057558 0.033955879509449005
+0.028908755630254745 0.001361989532597363 0.0010646788869053125 0.011514736339449883 0.028808623552322388
+0.027435997501015663 6.587553798453882e-05 0.00011241967877140269 0.011526210233569145 0.02370476722717285
+0.019538607448339462 0.001195459975861013 0.0013478185283020139 0.006398223340511322 0.011865685693919659
+0.023038407787680626 0.00016541096556466073 0.0002899064274970442 0.006125601474195719 0.01264035515487194
+0.02076919749379158 0.0010094785830006003 0.001184464548714459 0.009575976058840752 0.014735568314790726
+0.028732113540172577 0.0004002335772383958 0.0006326682632789016 0.01075471006333828 0.02692929469048977
+0.03212394937872887 0.0007733421516604722 0.0007161968387663364 0.013265007175505161 0.0424710214138031
+0.03030385822057724 0.0008174798567779362 0.0008688344387337565 0.014712565578520298 0.02749568223953247
+0.03234213963150978 0.00022301626449916512 0.000347991066519171 0.011155224405229092 0.044315263628959656
+0.016853852197527885 0.0010666385060176253 0.0008067151647992432 0.01407098863273859 0.0013157021021470428
+0.01772300712764263 3.759090759558603e-05 0.00015289774455595762 0.012336892075836658 0.007432577665895224
+0.029924282804131508 0.00047419543261639774 0.0006584688671864569 0.013373339548707008 0.021004928275942802
+0.035714004188776016 6.630329880863428e-05 0.00012968230294063687 0.012268676422536373 0.033444132655858994
+0.025540746748447418 0.00019209308084100485 0.00046021080925129354 0.015260309912264347 0.023117462173104286
+0.026118462905287743 3.561102130333893e-05 0.00015905339387245476 0.010189777240157127 0.01750301755964756
+0.022352177649736404 9.453583334106952e-05 0.00016799299919512123 0.008046925067901611 0.018084578216075897
+0.024372760206460953 3.652209852589294e-05 0.00020882152603007853 0.021770287305116653 0.014996159821748734
+0.02893691323697567 5.225001586950384e-05 0.00038294336991384625 0.01964794471859932 0.0347500704228878
+0.020065050572156906 0.0001467048132326454 0.0004304070316720754 0.010850814171135426 0.020601792261004448
+0.032215867191553116 4.516173794399947e-05 0.0002755385357886553 0.007080380339175463 0.025833718478679657
+0.02415790781378746 0.0008010324672795832 0.0003270755405537784 0.009342826902866364 0.01678246445953846
+0.022961057722568512 0.0003117657615803182 0.00019592490571085364 0.010163023136556149 0.006896728649735451
+0.02136952057480812 0.00039565397310070693 0.00041189571493305266 0.008188137784600258 0.01725953444838524
+0.033029358834028244 0.000496442022267729 0.00045423093251883984 0.021946880966424942 0.03220994025468826
+0.0296668391674757 7.831244147382677e-05 0.0003797328972723335 0.009319571778178215 0.03434666618704796
+0.028965463861823082 0.0002852763282135129 0.0007882025674916804 0.01595100201666355 0.03832618147134781
+0.033409710973501205 0.0001392580015817657 0.00020787626272067428 0.013737278990447521 0.0477614589035511
+0.016480296850204468 0.0005525126471184194 0.0010457284515723586 0.010846848599612713 0.0036136754788458347
+0.01760069467127323 9.565707296133041e-05 0.00020664015028160065 0.00978896114975214 0.00425475649535656
+0.031527478247880936 0.0015369803877547383 0.0007994546904228628 0.007250514347106218 0.039476245641708374
+0.025942377746105194 0.0006822997238487005 0.0008690913091413677 0.01073041558265686 0.01778707653284073
+0.024250918999314308 0.0006738360971212387 0.00028526101959869266 0.013300947844982147 0.003914260771125555
+0.02525470033288002 0.0019319219281896949 0.0008111662464216352 0.009467143565416336 0.015002373605966568
+0.03249247372150421 9.192246943712234e-05 0.0005859577795490623 0.01249566487967968 0.04018422216176987
+0.02215971052646637 0.0006411559297703207 0.0020031335297971964 0.022839652374386787 0.017286835238337517
+0.016550879925489426 6.927579670445994e-05 0.00033442871063016355 0.01531991921365261 0.00802350603044033
+0.027368778362870216 0.000825501570943743 0.0022933899890631437 0.010865761898458004 0.02013779804110527
+0.021443724632263184 6.226962432265282e-05 0.0006642176886089146 0.008827357552945614 0.015818651765584946
+0.02159843221306801 0.0021293421741575003 0.00039349275175482035 0.009121122770011425 0.01444430835545063
+0.026906754821538925 0.000827046693302691 0.00017542294517625123 0.006946816109120846 0.021702367812395096
+0.033965304493904114 0.00033807489671744406 0.000557278748601675 0.009403274394571781 0.032718975096940994
+0.02217107266187668 0.0006583372596651316 0.00024582690093666315 0.009343014098703861 0.014258881099522114
+0.02134380303323269 3.764974826481193e-05 0.0006842771545052528 0.008365037851035595 0.01669837348163128
+0.024163346737623215 9.503422188572586e-05 0.0011107450118288398 0.012552876025438309 0.02639508806169033
+0.021417170763015747 0.0002353991731069982 0.00010079846833832562 0.01246944721788168 0.009924323298037052
+0.022126857191324234 0.00050537841161713 0.0002760498027782887 0.011868387460708618 0.011113389395177364
+0.031032036989927292 6.626216054428369e-05 0.0001567322324262932 0.014112681150436401 0.03739028424024582
+0.03159279376268387 0.000523220282047987 0.00025915581500157714 0.014890510588884354 0.049617186188697815
+0.0308486707508564 0.00010379607556387782 0.00015187326062005013 0.009192749857902527 0.03993342071771622
+0.030575279146432877 0.00011601289588725194 0.0005592844681814313 0.009996662847697735 0.03352033346891403
+0.02385820262134075 4.058656850247644e-05 0.00019853193953167647 0.00837874785065651 0.017416780814528465
+0.02716989815235138 0.0002935791853815317 0.0003780521219596267 0.006794413551688194 0.010465322993695736
+0.029946863651275635 8.122621511574835e-05 0.0003121902409475297 0.008072542026638985 0.027379432693123817
+0.02578810043632984 0.000375022878870368 0.00016816302377264947 0.010869606398046017 0.023967407643795013
+0.02549479901790619 0.0006171765271574259 0.00034482087357901037 0.019561290740966797 0.030564166605472565
+0.027230482548475266 0.00015198941400740296 0.00026517879450693727 0.015415900386869907 0.027162935584783554
+0.024501994252204895 0.00025995876058004797 0.00011088125756941736 0.010682865977287292 0.01552570890635252
+0.029104549437761307 0.0008468463202007115 0.0001497803023084998 0.010608794167637825 0.03590865060687065
+0.024159222841262817 0.00030642046476714313 0.00019918897305615246 0.01621864177286625 0.02231973595917225
+0.0303343553096056 0.00019575793703552336 0.00038126856088638306 0.015314437448978424 0.0378284677863121
+0.03186555951833725 0.0002934909425675869 0.000249274104135111 0.009991475380957127 0.03565610572695732
+0.028405694290995598 9.809311450226232e-05 0.0005323721561580896 0.024338679388165474 0.011814115568995476
+0.03699508309364319 0.00022854487178847194 0.0012679054634645581 0.011513853445649147 0.024862755089998245
+0.026684409007430077 0.00017858200590126216 0.00032900599762797356 0.014326178468763828 0.026492638513445854
+0.025481685996055603 0.0009182339999824762 0.0004347909998614341 0.022705862298607826 0.023006629198789597
+0.029208693653345108 0.00042967431363649666 0.0005367753328755498 0.023809393867850304 0.038767531514167786
+0.028022846207022667 0.00016993730969261378 0.0002766810357570648 0.01813492737710476 0.036572568118572235
+0.03530053794384003 0.0008620748412795365 0.00010779813601402566 0.0163913331925869 0.036297865211963654
+0.03293471038341522 0.0004671339993365109 0.00015365863509941846 0.011986695230007172 0.03762257099151611
+0.028922779485583305 6.991611735429615e-05 0.0004145806306041777 0.012885811738669872 0.0055316900834441185
+0.02584415301680565 0.00010107927664648741 0.0002722470089793205 0.009022452868521214 0.015911292284727097
+0.03561829775571823 9.773196507012472e-05 0.0003440200234763324 0.0072456831112504005 0.008313140831887722
+0.027042273432016373 0.00012628862168639898 0.0010002361377701163 0.03726411983370781 0.009069702588021755
+0.031333547085523605 0.00040152386645786464 0.00020912195031996816 0.01104989554733038 0.023132730275392532
+0.029559345915913582 0.001075657899491489 0.00021108865621499717 0.011651215143501759 0.033913448452949524
+0.02068764716386795 0.0002179605362471193 0.00026894931215792894 0.02075287140905857 0.01886548288166523
+0.025038443505764008 0.0002341126964893192 0.00042119339923374355 0.015362289734184742 0.03309480845928192
+0.02199740521609783 0.00028255381039343774 0.00038824789226055145 0.01087616290897131 0.01598449796438217
+0.02181524783372879 0.00023006339324638247 0.0005933582433499396 0.010394968092441559 0.00304942368529737
+0.02774461917579174 0.00037773646181449294 0.00028292389470152557 0.010687360540032387 0.028931204229593277
+0.018890928477048874 0.0004290177603252232 0.0005882024997845292 0.012279413640499115 0.004044848028570414
+0.02268502488732338 0.00019582323147915304 0.00021835684310644865 0.01015468966215849 0.021069088950753212
+0.026526279747486115 0.001128366682678461 0.00028822501190006733 0.009063917212188244 0.010721693746745586
+0.028060782700777054 0.00032804717193357646 0.00022202574473340064 0.00869755819439888 0.03129231184720993
+0.027263911440968513 0.00019693812646437436 0.0010622430127114058 0.015829283744096756 0.017912287265062332
+0.027192657813429832 6.470915832323954e-05 0.0006198939518071711 0.013493859209120274 0.030715523287653923
+0.03153948113322258 0.00037407735362648964 0.0011924540158361197 0.011806812137365341 0.0432337149977684
+0.023636767640709877 5.2267048886278644e-05 0.0006468259380199015 0.015227748081088066 0.02780415676534176
+0.024105612188577652 0.0009450906072743237 0.00040794615051709116 0.01222152728587389 0.02207281067967415
+0.029556719586253166 0.0005962920258753002 0.0001278528943657875 0.010663416236639023 0.01881648227572441
+0.023508962243795395 6.160183693282306e-05 0.00018839802942238748 0.007472295314073563 0.014003348536789417
+0.025180678814649582 5.9335674450267106e-05 0.0002904014545492828 0.009724797680974007 0.021237920969724655
+0.023125197738409042 0.00026435410836711526 0.0005829824949614704 0.013713722117245197 0.01978437788784504
+0.02548963390290737 0.0008085091249085963 0.00016207917360588908 0.008811780251562595 0.02679489739239216
+0.022764766588807106 0.0003201451327186078 0.00020221981685608625 0.007641082163900137 0.022348813712596893
+0.024990148842334747 0.00012928337673656642 0.0003584126243367791 0.012536908499896526 0.026422979310154915
+0.03244588151574135 0.0001155513382400386 7.493457815144211e-05 0.01140904426574707 0.024938171729445457
+0.028096389025449753 6.256293272599578e-05 0.0009395019733346999 0.009877842850983143 0.00796960387378931
+0.02103429287672043 0.00032895695767365396 0.00030175907886587083 0.0058459010906517506 0.008496267721056938
+0.026817014440894127 0.0013428087113425136 0.0006758598610758781 0.010745111852884293 0.022475365549325943
+0.021299239248037338 7.796693535055965e-05 0.00026221442385576665 0.020750446245074272 0.009104720316827297
+0.029692132025957108 0.00244534807279706 0.0018470797222107649 0.011012987233698368 0.019499775022268295
+0.02674316242337227 0.0006150712724775076 0.0005261955666355789 0.007312473375350237 0.03481011092662811
+0.04210590198636055 0.001506088417954743 0.0020720732863992453 0.015532509423792362 0.0674881860613823
+0.023356080055236816 0.0006489831139333546 0.0008027803269214928 0.025381436571478844 0.017447028309106827
+0.03147425130009651 0.0019364471081644297 0.003836282528936863 0.012237265706062317 0.024502096697688103
+0.037828899919986725 0.00016681681154295802 0.0012377414386719465 0.013411308638751507 0.03529220074415207
+0.030227849259972572 0.005262391176074743 0.00448774266988039 0.008484822697937489 0.02793298102915287
+0.033537764102220535 0.0016322563169524074 0.003117106854915619 0.01655348390340805 0.03779938071966171
+0.033231575042009354 0.004650206305086613 0.0021298991050571203 0.01007716916501522 0.009993723593652248
+0.027184132486581802 0.004482210148125887 0.0019434703281149268 0.010643635876476765 0.018308935686945915
+0.02393874153494835 0.001199071411974728 0.0035272189415991306 0.01779760606586933 0.01419077068567276
+0.02903164178133011 0.002819937188178301 0.004790748003870249 0.024870647117495537 0.033005885779857635
+0.03580557182431221 0.0010320055298507214 0.0012667336268350482 0.013015025295317173 0.052782244980335236
+0.02650861069560051 0.003028700826689601 0.005010706372559071 0.012088032439351082 0.02239827625453472
+0.0191884133964777 0.0002834160695783794 0.00023023698304314166 0.010719281621277332 0.004655987489968538
+0.02559533528983593 0.004270350094884634 0.0009984745411202312 0.00634206086397171 0.00032942843972705305
+0.032917723059654236 0.0012216215254738927 0.0001776943972799927 0.00473334826529026 0.02332191728055477
+0.04122502729296684 0.0002735629095695913 0.000340028025675565 0.006620194297283888 0.041443075984716415
+0.0333498977124691 0.0012267486890777946 0.00040100523619912565 0.019254524260759354 0.015777505934238434
+0.030711248517036438 0.0006659578066319227 0.00031229297746904194 0.014543331228196621 0.01500982791185379
+0.02326829358935356 7.181792898336425e-05 0.0005961440620012581 0.013623406179249287 0.007106791716068983
+0.019893502816557884 0.0003224833635613322 0.0010213612113147974 0.016398539766669273 0.014368433505296707
+0.03134512901306152 5.8371129853185266e-05 0.00037185504334047437 0.026842352002859116 0.03707628324627876
+0.03986155986785889 0.0007082271622493863 0.0018178478349000216 0.01996924728155136 0.059887249022722244
+0.019494276493787766 7.261142309289426e-05 0.0012690480798482895 0.01937849260866642 0.0049392106011509895
+0.023078519850969315 0.0018407825846225023 0.00025180811644531786 0.010257975198328495 0.018178824335336685
+0.029966529458761215 0.0025672754272818565 0.0012245282996445894 0.010069457814097404 0.011730764992535114
+0.022863904014229774 5.416577187133953e-05 0.0003728554002009332 0.00770716555416584 0.015660995617508888
+0.020723531022667885 0.002153517445549369 0.00029431842267513275 0.009112481959164143 0.011500176042318344
+0.024045497179031372 0.0012388642644509673 0.00025683187413960695 0.00898510217666626 0.013958020135760307
+0.028434565290808678 0.00013211346231400967 0.002447740873321891 0.014579535461962223 0.026013081893324852
+0.022442063316702843 0.00017213619139511138 0.0018374084029346704 0.021862825378775597 0.019111622124910355
+0.02790820598602295 0.0006535188294947147 0.0010605895658954978 0.014988917857408524 0.02648366428911686
+0.017741059884428978 0.0005213927943259478 0.0023994934745132923 0.009602999314665794 0.009184136055409908
+0.020782817155122757 0.0007984506664797664 0.0004282848094590008 0.009986207820475101 0.012710615992546082
+0.021574223414063454 0.0030301527585834265 0.00018898599955718964 0.005330237094312906 0.013906876556575298
+0.02280360832810402 0.0007066638208925724 0.00014886407006997615 0.0051772198639810085 0.01702071540057659
+0.019077472388744354 0.0001707584160612896 0.000471599109005183 0.00579858897253871 0.012810689397156239
+0.03106461465358734 0.00045272716670297086 0.00012740280362777412 0.011068142019212246 0.02719295211136341
+0.03297824785113335 0.00028612377354875207 0.0005039940006099641 0.010968683287501335 0.04222724586725235
+0.030536774545907974 0.00017926307918969542 0.0010201799450442195 0.01105524506419897 0.04737262800335884
+0.024762475863099098 0.00021026897593401372 0.00068919628392905 0.013836676254868507 0.00759515818208456
+0.020582430064678192 0.0002027575537795201 0.00031331137870438397 0.013698559254407883 0.015118052251636982
+0.02776057831943035 0.00010095768084283918 0.0002464531862642616 0.013780576176941395 0.020418832078576088
+0.03035038150846958 0.0005399552173912525 0.0003323548589833081 0.01529508363455534 0.0303826741874218
+0.02456315979361534 0.0001962645328603685 0.0003448286443017423 0.014324229210615158 0.014841409400105476
+0.020127609372138977 0.00013100421347189695 0.00042446734732948244 0.011466806754469872 0.009310095570981503
+0.02324226126074791 0.00015082255413290113 0.00035191027563996613 0.006906080991029739 0.029223458841443062
+0.030014166608452797 0.00011844748223666102 0.00027974904514849186 0.012341022491455078 0.033131193369627
+0.017618944868445396 6.897420098539442e-05 0.00019871302356477827 0.009402633644640446 0.006408296525478363
+0.02596692368388176 0.00011408094724174589 0.00035571149783208966 0.008532965555787086 0.015360260382294655
+0.022002093493938446 3.0109933504718356e-05 0.0001667182514211163 0.009312228299677372 0.0075711472891271114
+0.03251047432422638 0.00021237583132460713 0.0003077714645769447 0.012365192174911499 0.030154608190059662
+0.02686578780412674 4.9810001655714586e-05 0.00018564070342108607 0.014748058281838894 0.01804169826209545
+0.025553304702043533 0.00022357728448696434 0.0002884432324208319 0.013756290078163147 0.007411369122564793
+0.03187040984630585 8.02234499133192e-05 0.0001302077725995332 0.010864497162401676 0.042526766657829285
+0.023490697145462036 0.0002016743819694966 0.0006765666767023504 0.016191812232136726 0.018523652106523514
+0.025425121188163757 3.627158730523661e-05 0.00016275278176181018 0.011749133467674255 0.015197422355413437
+0.02377082034945488 0.00042781606316566467 0.0004707946500275284 0.008190719410777092 0.009520221501588821
+0.03024732694029808 4.371661270852201e-05 0.00014701875625178218 0.009385433979332447 0.02764067053794861
+0.03534607216715813 0.00043702396214939654 0.00022248548339121044 0.04396624490618706 0.02935902588069439
+0.0271291621029377 0.00020152993965893984 5.029531166655943e-05 0.007697963621467352 0.011103699915111065
+0.023109808564186096 9.729155863169581e-05 0.00044521468225866556 0.007706908043473959 0.01965244486927986
+0.028804197907447815 5.775729368906468e-05 0.00014481185644399375 0.013652044348418713 0.02932647056877613
+0.019577283412218094 8.186508057406172e-05 0.0003408892371226102 0.014106555841863155 0.017199616879224777
+0.03314363211393356 8.39233907754533e-05 0.00019180700473953038 0.008851961232721806 0.05587054789066315
+0.023236986249685287 9.228244016412646e-05 0.00020778836915269494 0.0119321309030056 0.008829636499285698
+0.04012559726834297 0.00010949512216029689 0.0002955430536530912 0.010039939545094967 0.03336496651172638
+0.02832089737057686 0.00010299162386218086 0.000255122606176883 0.009774812497198582 0.013301240280270576
+0.028264867141842842 8.40447100927122e-05 0.00037488018278963864 0.008043062873184681 0.022953886538743973
+0.03886318951845169 0.00017586465401109308 0.0007150596939027309 0.0133455665782094 0.061161983758211136
+0.045483268797397614 0.0003922184114344418 0.0006531636463478208 0.023182209581136703 0.05613813176751137
+0.02226851135492325 0.0006687594577670097 0.0005928427563048899 0.01334704365581274 0.017416832968592644
+0.027485240250825882 5.092989158583805e-05 0.00041440658969804645 0.007257797755300999 0.029338931664824486
+0.02450765110552311 0.0014119971310719848 0.00025201402604579926 0.012751637026667595 0.017413638532161713
+0.01973104476928711 0.0008359382627531886 9.598392352927476e-05 0.010585102252662182 0.012305290438234806
+0.02790963277220726 0.00011409032595111057 0.0007348532672040164 0.008705119602382183 0.006902102380990982
+0.022506659850478172 0.00019641387916635722 0.0003006962360814214 0.006805469281971455 0.005595083814114332
+0.025470905005931854 0.0001480491628171876 0.0007155620842240751 0.00847130548208952 0.0310544203966856
+0.023846279829740524 8.888456068234518e-05 0.0012022483861073852 0.01670585572719574 0.020773611962795258
+0.026332607492804527 0.0009200581698678434 0.000277486105915159 0.012443862855434418 0.025873886421322823
+0.02272765152156353 0.0009245125111192465 0.0004224067961331457 0.008612815290689468 0.021942399442195892
+0.016781892627477646 0.00024325249250978231 0.0002665433567017317 0.011688976548612118 0.010920771397650242
+0.03661288321018219 0.0008590011275373399 0.0002468026359565556 0.010397342965006828 0.04678564518690109
+0.027460550889372826 9.256248449673876e-05 0.0003784997097682208 0.012042030692100525 0.03758382797241211
+0.024164779111742973 8.761281060287729e-05 0.0007375080604106188 0.010286383330821991 0.013389058411121368
+0.026212722063064575 8.31215875223279e-05 0.0001510455331299454 0.008305958472192287 0.018294086679816246
+0.030738594010472298 0.000403090612962842 0.0005248051602393389 0.008852132596075535 0.019660843536257744
+0.025270801037549973 3.0188088203431107e-05 0.0003109943645540625 0.010443835519254208 0.0034311963245272636
+0.02518678456544876 0.0006996799493208528 0.00014200681471265852 0.009871749207377434 0.023835811764001846
+0.024836668744683266 0.0005629394436255097 0.00026691172388382256 0.014348643831908703 0.013955457136034966
+0.02782765030860901 0.00010811660467879847 0.0002522112044971436 0.02178792469203472 0.02766837365925312
+0.024706032127141953 4.226803139317781e-05 0.00019575352780520916 0.011895912699401379 0.0285281240940094
+0.018224501982331276 0.00018803900456987321 0.00028342375298961997 0.00818637479096651 0.01566782221198082
+0.021635936573147774 0.0006071043317206204 6.801170820835978e-05 0.008530888706445694 0.009936200454831123
+0.020497824996709824 0.00021166411170270294 8.61811640788801e-05 0.008110804483294487 0.012041943147778511
+0.029514221474528313 0.00012663131929002702 0.0002316304307896644 0.009270629845559597 0.022767670452594757
+0.021287687122821808 7.100302900653332e-05 0.00010865273361559957 0.007283993065357208 0.012007191777229309
+0.02411743812263012 4.4457061449065804e-05 0.0006647182162851095 0.008566062897443771 0.023513076826930046
+0.02886219695210457 9.252881136490032e-05 0.0005256772856228054 0.008040430024266243 0.04610653221607208
+0.019482364878058434 0.0006833578227087855 9.052696987055242e-05 0.011072801426053047 0.009486669674515724
+0.019381236284971237 0.00042959899292327464 0.00037728738971054554 0.01045681070536375 0.01595921255648136
+0.016137635335326195 0.00031052090344019234 8.675941353430972e-05 0.008448061533272266 0.004721761215478182
+0.02130410075187683 0.0015106771606951952 0.0005263324128463864 0.010046259500086308 0.018807202577590942
+0.0212510097771883 0.0003682476235553622 0.00029231313965283334 0.011155771091580391 0.011729919351637363
+0.022330544888973236 0.0004199454269837588 0.0004707501211669296 0.019549088552594185 0.0023543694987893105
+0.018779968842864037 0.0008796335314400494 0.0006764968857169151 0.011394909583032131 0.004875236190855503
+0.024972356855869293 4.299500506022014e-05 0.00032709931838326156 0.01633596234023571 0.011403088457882404
+0.030054951086640358 0.00030162828625179827 0.0009866687469184399 0.021151985973119736 0.04220344126224518
+0.03663421794772148 0.000175892622792162 0.0002985993050970137 0.020572129637002945 0.05501752346754074
+0.029984891414642334 0.0008919440442696214 0.002273046411573887 0.01151454821228981 0.03303522989153862
+0.026746392250061035 8.983653242466971e-05 0.0007330995867960155 0.007736781146377325 0.0179153960198164
+0.02414909563958645 0.003295178757980466 0.0006447760970331728 0.007564928382635117 0.020951518788933754
+0.02756117470562458 0.0021857875399291515 0.0012584879295900464 0.007889453321695328 0.012179598212242126
+0.01899423822760582 0.0004623626300599426 7.470735727110878e-05 0.008635450154542923 0.0060783689841628075
+0.022627051919698715 0.003109953599050641 0.0003929587546736002 0.009265989996492863 0.011477654799818993
+0.017302270978689194 0.0005080254632048309 0.0008011462050490081 0.011209554970264435 0.007825043983757496
+0.029708661139011383 0.00021480608847923577 0.002802518429234624 0.016939401626586914 0.04319940134882927
+0.028523748740553856 9.33278352022171e-05 0.0005162135348655283 0.024421071633696556 0.04206611216068268
+0.016997652128338814 0.0008930749609135091 0.0016849864041432738 0.014501001685857773 0.011802478693425655
+0.028096653521060944 5.036082802689634e-05 0.0015869734343141317 0.0107249915599823 0.005689411889761686
+0.02105848304927349 0.002373238792642951 8.089817129075527e-05 0.00835482683032751 0.011386081576347351
+0.02534369006752968 0.002738187788054347 0.00026632766821421683 0.007382106967270374 0.018724728375673294
+0.028643203899264336 5.799824430141598e-05 0.00036987365456297994 0.00521283783018589 0.035067442804574966
+0.03419801592826843 0.001374468905851245 0.000820862827822566 0.010686135850846767 0.024087224155664444
+0.022251268848776817 0.0008918749517761171 0.0002250721590826288 0.010655254125595093 0.025668654590845108
+0.03303207829594612 8.96606725291349e-05 0.0007231244817376137 0.008548414334654808 0.045812249183654785
+0.03125326335430145 0.0002692693960852921 0.0009396281675435603 0.008331681601703167 0.04120955243706703
+0.030696334317326546 8.78095452208072e-05 0.00018701510271057487 0.010133511386811733 0.04545743390917778
+0.025059077888727188 0.00031819529249332845 0.0005770965362899005 0.015028764493763447 0.02694983035326004
+0.01525996346026659 4.4990494643570855e-05 0.00024491091608069837 0.013039317913353443 0.00600298261269927
+0.027807706966996193 0.0007307668565772474 0.00020751164993271232 0.010666904039680958 0.02880769595503807
+0.01842709444463253 0.000885364948771894 0.0006850839708931744 0.008290919475257397 0.00043297113734297454
+0.02064530923962593 4.3683092371793464e-05 0.00018945126794278622 0.010504527017474174 0.010772226378321648
+0.028792981058359146 0.0010510842548683286 0.00014258814917411655 0.012203192338347435 0.017603641375899315
+0.03041970543563366 0.0007716991822235286 0.00014819057832937688 0.015612768009305 0.028463123366236687
+0.021500669419765472 5.480119216372259e-05 0.0003941506438422948 0.01688266173005104 0.008146417327225208
+0.022276023402810097 6.339305400615558e-05 0.00018617554451338947 0.012746947817504406 0.02517220377922058
+0.03650001436471939 5.7590634241933e-05 0.00043152313446626067 0.019801098853349686 0.03877648711204529
+0.026262301951646805 2.871737342502456e-05 0.001241235644556582 0.011697882786393166 0.020356034860014915
+0.024179572239518166 0.00029469464789144695 0.0006971328984946012 0.008540130220353603 0.022446217015385628
+0.026590708643198013 0.001400753972120583 0.00013090166612528265 0.006261325906962156 0.02443203330039978
+0.02056014910340309 0.0005372672458179295 0.0003787064924836159 0.007763945497572422 0.008291014470160007
+0.03113393858075142 0.0005071378545835614 0.00012639915803447366 0.009411482140421867 0.022689003497362137
+0.03036387264728546 0.0011097422102466226 7.666906458325684e-05 0.005641999654471874 0.027356408536434174
+0.021351538598537445 8.965165034169331e-05 0.0012217635521665215 0.005805369466543198 0.01258417870849371
+0.01815851405262947 5.9859205066459253e-05 0.00170612009242177 0.010535156354308128 0.012316875159740448
+0.024777796119451523 0.0002903557615354657 0.00016485867672599852 0.01180338580161333 0.023429732769727707
+0.021633446216583252 0.00043900564196519554 0.0006521369214169681 0.011727056466042995 0.008983773179352283
+0.019573137164115906 0.00010103126260219142 0.00040239220834337175 0.014198390766978264 0.013055489398539066
+0.024388227611780167 0.000841519155073911 0.0004382941115181893 0.015032918192446232 0.02210138365626335
+0.028060056269168854 0.00027768214931711555 0.0004875852318946272 0.010542037896811962 0.012822559103369713
+0.021929122507572174 3.5990738979307935e-05 0.0009759000968188047 0.009022556245326996 0.01739971898496151
+0.019765079021453857 4.448267281986773e-05 0.00028534833109006286 0.007921576499938965 0.01708570495247841
+0.028677459806203842 5.749071351601742e-05 7.104208634700626e-05 0.00879676267504692 0.037619054317474365
+0.024079326540231705 8.189190702978522e-05 8.284726209240034e-05 0.009536894038319588 0.017599599435925484
+0.02233891561627388 0.0002091982460115105 9.77271847659722e-05 0.006963054649531841 0.008379240520298481
+0.02331922948360443 6.4150444814004e-05 0.000267581723164767 0.005849442910403013 0.00698462687432766
+0.016293678432703018 0.00026869482826441526 0.0009812423959374428 0.007759062107652426 0.005483156070113182
+0.018611036241054535 4.468828410608694e-05 0.000522472953889519 0.010721206665039062 0.015875374898314476
+0.020750515162944794 0.0005320183699950576 0.0002429519809084013 0.01395442895591259 0.015018372796475887
+0.020441176369786263 0.0003469246148597449 0.0003321921976748854 0.0169361624866724 0.01158605981618166
+0.02037980780005455 5.077814057585783e-05 0.00020142600988037884 0.011470603756606579 0.014327436685562134
+0.027874618768692017 0.00020072523329872638 9.710234735393897e-05 0.010374623350799084 0.028462912887334824
+0.026218056678771973 7.726854528300464e-05 0.00012394053919706494 0.009953144006431103 0.027590632438659668
+0.03414569050073624 5.071537452749908e-05 0.0005391293088905513 0.01025087758898735 0.0419357605278492
+0.03421260789036751 3.273836773587391e-05 0.0004573026963043958 0.009672688320279121 0.04757518693804741
+0.026133526116609573 0.0003167561662849039 9.313213377026841e-05 0.010563836432993412 0.023273972794413567
+0.020338254049420357 0.00040071186958812177 0.00023004486865829676 0.008820011280477047 0.013173704966902733
+0.02131063863635063 6.503217446152121e-05 0.0001230579655384645 0.009538630954921246 0.010168644599616528
+0.014883387833833694 0.00041110446909442544 0.0003534007410053164 0.011405976489186287 0.005510716699063778
+0.020579898729920387 0.00012079264706699178 0.0006233318708837032 0.017770124599337578 0.015617204830050468
+0.020130235701799393 7.452532008755952e-05 0.0005772129516117275 0.015970440581440926 0.022030390799045563
+0.018316183239221573 0.0001528131397208199 0.00026603968581184745 0.014454192481935024 0.017450537532567978
+0.020554767921566963 0.00010458665929036215 0.00028277697856538 0.01064692810177803 0.014900295995175838
+0.019341174513101578 0.00018143146007787436 0.00013824441703036427 0.006955592427402735 0.014638060703873634
+0.020984947681427002 0.0003967838711105287 8.125699969241396e-05 0.006809167098253965 0.01628957875072956
+0.02020888589322567 0.00013779624714516103 0.00013979269715491682 0.004995326045900583 0.004716195166110992
+0.015825731679797173 4.483177326619625e-05 0.00013041852798778564 0.005139966029673815 0.0029952023178339005
+0.02633163332939148 4.9031117669073865e-05 0.0001939588546520099 0.009363071992993355 0.03270625323057175
+0.020013030618429184 0.0001312003587372601 0.0005377219640649855 0.014073883183300495 0.01608048751950264
+0.019055230543017387 0.0005050406325608492 0.00031287624733522534 0.010800456628203392 0.02323298715054989
+0.024191826581954956 0.0001090628866222687 0.00027483084704726934 0.010898351669311523 0.0261320061981678
+0.021153103560209274 0.00046363414730876684 0.0004520276270341128 0.01251970138400793 0.014635083265602589
+0.022245055064558983 0.0003405116149224341 0.00023117268574424088 0.011521706357598305 0.01865851692855358
+0.020845288410782814 7.578245276818052e-05 0.000300006999168545 0.008734076283872128 0.004384973086416721
+0.012856324203312397 9.652077278587967e-05 0.00014325050869956613 0.008114149793982506 0.0020878484938293695
+0.015308580361306667 0.0001311419764533639 0.0004176199436187744 0.010163407772779465 0.007639269344508648
+0.02264557033777237 3.866633778670803e-05 0.0003402491856832057 0.010489849373698235 0.017108848318457603
+0.021812669932842255 0.0005301910568960011 0.00027591249090619385 0.012777832336723804 0.025169676169753075
+0.017684923484921455 0.000465065473690629 0.0006940997554920614 0.006489828694611788 0.006159840151667595
+0.022399134933948517 0.0002841560635715723 0.00019410275854170322 0.00782009493559599 0.029922638088464737
+0.026127947494387627 0.0011030996683984995 0.00041035201866179705 0.00973381008952856 0.025420527905225754
+0.022711817175149918 5.040243559051305e-05 0.00014662316243629903 0.009643805213272572 0.008686327375471592
+0.018849220126867294 0.0003645333636086434 0.000698234885931015 0.00932006910443306 0.008763928897678852
+0.025761593133211136 2.972364927700255e-05 0.00012118323502363637 0.04538770392537117 0.013219133950769901
+0.023263372480869293 0.00028074192232452333 0.0010715114185586572 0.012469640001654625 0.021124232560396194
+0.028664138168096542 0.00018885478493757546 8.301399793708697e-05 0.0124822361394763 0.037745408713817596
+0.03877243399620056 0.001251443987712264 0.0008875387138687074 0.00641140341758728 0.05040594935417175
+0.023139439523220062 5.08770790474955e-05 0.0001271576911676675 0.008727055042982101 0.030200673267245293
+0.021490488201379776 0.0009318296797573566 0.000333877484081313 0.012636354193091393 0.019250959157943726
+0.021132078021764755 0.0002956095850095153 8.20207133074291e-05 0.006673957221210003 0.010186844505369663
+0.03133229911327362 2.2595459086005576e-05 0.0004093244788236916 0.007939525879919529 0.03434564918279648
+0.03132615610957146 0.00012075299309799448 0.00014152655785437673 0.008694232441484928 0.014806280843913555
+0.02727879025042057 5.135679748491384e-05 0.0005781048093922436 0.011873017065227032 0.014101053588092327
+0.017331477254629135 0.0003917778085451573 0.00012067729403497651 0.007669642101973295 0.0047090789303183556
+0.022186988964676857 0.0003088299126829952 0.00018002375145442784 0.007646848913282156 0.015734590590000153
+0.02069927379488945 6.392437353497371e-05 0.0008929440518841147 0.017733657732605934 0.012981781736016273
+0.024945201352238655 0.00012993236305192113 0.00033094242098741233 0.02337496168911457 0.017730440944433212
+0.036086149513721466 0.0005809668218716979 0.0014576453249901533 0.010524975135922432 0.05582072585821152
+0.015261399559676647 0.0002002944820560515 0.0002911387709900737 0.0063292705453932285 0.0026720622554421425
+0.01579345390200615 0.0018791505135595798 0.00019169325241819024 0.005902652628719807 0.0073922621086239815
+0.029266027733683586 0.0005278427270241082 6.253943865885958e-05 0.006800083443522453 0.027972517535090446
+0.02263675630092621 0.00015108771913219243 0.0003610114217735827 0.00568246329203248 0.013027460314333439
+0.022591430693864822 0.00029864837415516376 7.615711365360767e-05 0.006953917909413576 0.017274431884288788
+0.014519181102514267 6.318141822703183e-05 0.0005255067371763289 0.009717434644699097 0.0029276544228196144
+0.022295594215393066 0.0001054254244081676 0.00035053936881013215 0.009094237349927425 0.021586205810308456
+0.0199141725897789 0.00019018146849703044 0.00013430565013550222 0.010401141829788685 0.017508110031485558
+0.023922454565763474 5.6305867474293336e-05 0.00014736902085132897 0.01290268637239933 0.02590666338801384
+0.019931478425860405 6.687630957458168e-05 9.434650564799085e-05 0.012520517222583294 0.015178092755377293
+0.024847280234098434 8.562643051845953e-05 0.00023644125030841678 0.012968083843588829 0.02188362181186676
+0.026221219450235367 0.00014156549877952784 0.00025057580205611885 0.007473994046449661 0.023675911128520966
+0.022023823112249374 0.0005553844966925681 0.00011734865984180942 0.006811770610511303 0.019307686015963554
+0.02346690371632576 0.0001768455549608916 0.00026879541110247374 0.007850468158721924 0.026464780792593956
+0.020756462588906288 0.00044917271588929 0.00012186206004116684 0.00828098226338625 0.007102352567017078
+0.022922120988368988 0.0008773416629992425 0.00031141191720962524 0.0058843959122896194 0.023215876892209053
+0.016827693209052086 4.410133624332957e-05 0.0005044252029620111 0.006850524805486202 0.01174784917384386
+0.016231875866651535 0.0003644776006694883 0.001171636045910418 0.012887564487755299 0.009261393919587135
+0.027864964678883553 0.00015743300900794566 0.0001798159210011363 0.01134627778083086 0.005885071586817503
+0.02628142014145851 0.0006130608380772173 0.0011689970269799232 0.01103979256004095 0.019905727356672287
+0.026055604219436646 0.000180470960913226 0.0004974358016625047 0.015715206041932106 0.03331349417567253
+0.025424541905522346 0.0009612701833248138 0.0005618917639367282 0.015008961781859398 0.029451686888933182
+0.0225796140730381 4.664889638661407e-05 0.00025656833895482123 0.010413524694740772 0.024784091860055923
+0.019411560148000717 0.00042201197356916964 0.0009095012792386115 0.013239831663668156 0.018353110179305077
+0.02255820855498314 6.194584420882165e-05 0.0001594845234649256 0.008005681447684765 0.018764853477478027
+0.023848943412303925 0.0002733321161940694 0.00019149236322846264 0.007271007634699345 0.00937428418546915
+0.023712944239377975 0.0002721419732552022 0.0004157513321842998 0.006168179214000702 0.012484262697398663
+0.02217862941324711 6.829718768130988e-05 8.34290258353576e-05 0.009806763380765915 0.007996032945811749
+0.02824391797184944 0.0005070564802736044 0.00018495354743208736 0.016284272074699402 0.018543755635619164
+0.02293016016483307 0.00018944194016512483 0.00011723711213562638 0.01010146178305149 0.01439057756215334
+0.02842114493250847 0.00011601585720200092 0.0005290806293487549 0.011731664650142193 0.032179296016693115
+0.024790985509753227 0.00010033693979494274 0.00011150872160214931 0.009507223032414913 0.028956951573491096
+0.031317684799432755 6.117387238191441e-05 0.00018621212802827358 0.010766375809907913 0.04994293302297592
+0.020940430462360382 7.003558857832104e-05 0.0002878524537663907 0.010635444894433022 0.007711194455623627
+0.020516062155365944 4.0613605960970744e-05 0.00021218367328401655 0.0050210352055728436 0.014079735614359379
+0.033377163112163544 0.0004886469687335193 0.00010255219240207225 0.0035184198059141636 0.048735879361629486
+0.017446134239435196 0.000525421928614378 0.0005375997861847281 0.005784108769148588 0.0006303086411207914
+0.03163624554872513 0.0003080420719925314 0.00010211962216999382 0.009159965440630913 0.009811031632125378
+0.01803089678287506 0.0012835584348067641 0.000608624133747071 0.007430718746036291 0.010462498292326927
+0.019706442952156067 5.207074718782678e-05 0.00043591795838437974 0.008546140976250172 0.014425604604184628
+0.024566438049077988 0.0009084229241125286 0.002296656835824251 0.021430909633636475 0.029118910431861877
+0.02100232243537903 0.00019486714154481888 0.0004473112931009382 0.025544561445713043 0.02533694915473461
+0.023568585515022278 0.00195630663074553 0.0030912761576473713 0.012586378492414951 0.02030092477798462
+0.022036591544747353 7.059385097818449e-05 0.00031933459104038775 0.008497281931340694 0.028450291603803635
+0.026875684037804604 0.003492778865620494 0.0014746749075129628 0.012851056642830372 0.033597562462091446
+0.027514655143022537 0.0004329736693762243 0.00011935525253647938 0.0071028172969818115 0.026300830766558647
+0.023033535107970238 0.001244105282239616 0.0017721755430102348 0.008311156183481216 0.001929328776896
+0.03434603288769722 0.00019083796360064298 6.335480429697782e-05 0.006822314579039812 0.031783897429704666
+0.02372361347079277 0.0008382838568650186 0.0029339075554162264 0.015880143269896507 0.01565789431333542
+0.027459051460027695 6.879471766296774e-05 0.0006752584595233202 0.010331086814403534 0.03266240656375885
+0.023747805505990982 0.0020697673317044973 0.0008868796867318451 0.009423265233635902 0.028810100629925728
+0.023673895746469498 0.0003999538312200457 0.0003190112765878439 0.01137126237154007 0.025231601670384407
+0.01712099276483059 0.0008256251458078623 0.0005341039504855871 0.017282448709011078 0.010686506517231464
+0.023735934868454933 0.0005714145372621715 0.0001520171936135739 0.007588488515466452 0.030324220657348633
+0.022551320493221283 3.4799104469129816e-05 0.0008503013523295522 0.009250946342945099 0.011625340208411217
+0.03216658905148506 4.9469184887129813e-05 0.0002943709841929376 0.01148073747754097 0.016176927834749222
+0.030858226120471954 0.00017115364607889205 0.00014470804308075458 0.006815493106842041 0.025947164744138718
+0.02443978562951088 3.599367482820526e-05 0.00010340936569264159 0.006906128488481045 0.015030044130980968
+0.022893410176038742 0.0002804505347739905 0.00012040095316478983 0.008144055493175983 0.02448531799018383
+0.025595221668481827 0.000260150118265301 0.00011191501107532531 0.00860182661563158 0.03846345469355583
+0.03330472856760025 6.964665954001248e-05 0.00013737457629758865 0.009899471886456013 0.042326752096414566
+0.016209665685892105 8.107860048767179e-05 0.00015601227642036974 0.00887911394238472 0.006071516312658787
+0.02308916300535202 0.00020349085389170796 0.00010603423288557678 0.014906967990100384 0.009255023673176765
+0.02485077641904354 0.0005366048426367342 0.0003539101453498006 0.007468960713595152 0.010302622802555561
+0.028362737968564034 9.932603279594332e-05 0.0004568427975755185 0.00987519882619381 0.010078556835651398
+0.01995142549276352 0.0006779319955967367 0.0003323537821415812 0.007154232356697321 0.0034728669561445713
+0.020489700138568878 0.0005796508630737662 0.0002648197114467621 0.006294512189924717 0.000970347027759999
+0.027402140200138092 0.00016867267549969256 0.0008101330604404211 0.010890754871070385 0.03273859992623329
+0.019948039203882217 0.00023200605937745422 0.0010187781881541014 0.02329975925385952 0.012970661744475365
+0.024635443463921547 0.0005989065393805504 0.0016141467494890094 0.009534791111946106 0.03301069140434265
+0.030153729021549225 0.00011455440107965842 0.001475914497859776 0.007900338619947433 0.025954000651836395
+0.020866628736257553 0.0018009754130616784 0.00031342022703029215 0.00985011737793684 0.021756866946816444
+0.032788440585136414 0.0016236562514677644 0.0005892821936868131 0.009219658561050892 0.04627003148198128
+0.01864415965974331 0.00013504833623301238 0.0001687933545326814 0.009475714527070522 0.011406461708247662
+0.026913484558463097 0.0011357839684933424 0.00024247811234090477 0.007282854989171028 0.005584521219134331
+0.026477308943867683 0.000138536372105591 0.0007666972815059125 0.01160472072660923 0.0020775683224201202
+0.026805277913808823 7.594056660309434e-05 0.001511316979303956 0.010678278282284737 0.02523050084710121
+0.01720823533833027 0.00022317725233733654 0.000293962424620986 0.013695240952074528 0.010061685927212238
+0.021893564611673355 0.0008339454652741551 0.0006580987828783691 0.014769488014280796 0.021986648440361023
+0.023439966142177582 8.295565930893645e-05 0.0008947491878643632 0.016876816749572754 0.029450761154294014
+0.02344677224755287 0.0010195239447057247 0.00023187171609606594 0.010824333876371384 0.03218919783830643
+0.02236327901482582 0.002072377596050501 5.810713628306985e-05 0.009249057620763779 0.021490011364221573
+0.02047526091337204 0.00039109750650823116 0.00023662143212277442 0.006843164563179016 0.014468450099229813
+0.02237599343061447 0.00018003996228799224 0.00036729752901010215 0.006798022426664829 0.002292176242917776
+0.023664025589823723 0.00022762553999200463 0.00012254697503522038 0.005492592696100473 0.00577555550262332
+0.02700953558087349 5.656960638589226e-05 0.0012276237830519676 0.005179267376661301 0.013072816655039787
+0.016116954386234283 0.00011201304732821882 0.0015165312215685844 0.009122999384999275 0.002607033122330904
+0.02571428008377552 0.0010263347066938877 0.0005481779226101935 0.016015449538826942 0.01978922262787819
+0.02426055260002613 0.0009706209530122578 0.0006613308214582503 0.01754755899310112 0.025815537199378014
+0.02970755845308304 6.516682333312929e-05 0.0005117172258906066 0.010884292423725128 0.03340837359428406
+0.018080998212099075 0.0018352887127548456 0.00025317262043245137 0.00998302735388279 0.016479622572660446
+0.021844524890184402 0.0011794763850048184 0.00012713746400550008 0.00794754084199667 0.01739547774195671
+0.017226042225956917 4.0103976061800495e-05 0.000878616760019213 0.007387697696685791 0.006734071299433708
+0.0233675055205822 0.00010336523700971156 0.0003395246167201549 0.005562394391745329 0.011493973433971405
+0.023983053863048553 0.0001031279462040402 0.0006499473238363862 0.008874834515154362 0.0048662591725587845
+0.02010488137602806 4.605905633070506e-05 0.0008780331118032336 0.0070208036340773106 0.010631761513650417
+0.025642240419983864 0.0007266887696459889 0.0002443537814542651 0.009811989031732082 0.026324711740016937
+0.023291103541851044 0.000753010215703398 0.0003181768406648189 0.016422390937805176 0.019236218184232712
+0.030038436874747276 0.00018906081095337868 0.0003491581592243165 0.012624766677618027 0.04330043867230415
+0.025381673127412796 0.00016291355132125318 0.00027525422046892345 0.01081826537847519 0.031214267015457153
+0.01952877640724182 0.000748876656871289 0.0003071942483074963 0.010720547288656235 0.011819617822766304
+0.028988350182771683 0.0003211805014871061 0.00040008907672017813 0.007840155623853207 0.007193863857537508
+0.02033832110464573 4.119989534956403e-05 0.0008292750571854413 0.006217468064278364 0.010884687304496765
+0.02415144070982933 8.548246114514768e-05 0.0001504951942479238 0.006255850195884705 0.021480323746800423
+0.02367757074534893 0.0001368976227240637 0.0006072473479434848 0.012895232066512108 0.01799141615629196
+0.024078309535980225 0.00039679507608525455 0.0002646155480761081 0.013476542197167873 0.03092340938746929
+0.021496940404176712 0.0012788254534825683 0.0008803698001429439 0.011350651271641254 0.013552965596318245
+0.028242498636245728 4.57289679616224e-05 0.0001816377480281517 0.009059468284249306 0.03316221013665199
+0.02435707114636898 0.001802189159207046 0.0010981573723256588 0.010309353470802307 0.009742573834955692
+0.018162880092859268 0.0011087338207289577 0.0009522138861939311 0.010512047447264194 0.00047524870024062693
+0.022747516632080078 0.00028689001919701695 0.00012460794823709875 0.009306457825005054 0.0027245492674410343
+0.023988481611013412 0.0011394035536795855 0.00018218005425296724 0.011160420253872871 0.0254024937748909
+0.022409457713365555 0.0001923015806823969 0.0003275898052379489 0.011537335813045502 0.025893094018101692
+0.024800216779112816 5.7286211813334376e-05 0.0004223129362799227 0.013953412882983685 0.023856135085225105
+0.02577451430261135 4.642390558728948e-05 0.0002156868577003479 0.0117096658796072 0.028039269149303436
+0.028767114505171776 5.055023939348757e-05 0.0008094216464087367 0.008842473849654198 0.04624703526496887
+0.020387019962072372 0.00039053766522556543 0.00028481759363785386 0.009409132413566113 0.011204971931874752
+0.032454170286655426 0.0007006822270341218 0.0001221594138769433 0.018946915864944458 0.03483873978257179
+0.016319090500473976 0.0003223807143513113 9.717654756968841e-05 0.0089182760566473 0.0031948951072990894
+0.026272011920809746 3.759583705686964e-05 9.925611811922863e-05 0.008123734965920448 0.01331140287220478
+0.019136706367135048 0.0001568742882227525 0.00031798568670637906 0.03318861871957779 0.003586158389225602
+0.026738442480564117 0.00046459431177936494 0.00048042534035630524 0.012374566867947578 0.025616588070988655
+0.01900365576148033 0.0006640565115958452 0.0005635701236315072 0.014209264889359474 0.013633632101118565
+0.01780225709080696 0.00041927985148504376 0.00022735192032996565 0.013975206762552261 0.010257930494844913
+0.022651134058833122 6.945760105736554e-05 0.000282814638921991 0.010682676918804646 0.02431335486471653
+0.016257554292678833 0.0005503102438524365 0.00028847609064541757 0.008397893980145454 0.012146303430199623
+0.02921213209629059 0.0008633368997834623 0.00018205921514891088 0.008883261121809483 0.03075624443590641
+0.01986580714583397 0.0003256372583564371 0.00039149547228589654 0.007917480543255806 0.015462216921150684
+0.021951429545879364 5.767543552792631e-05 0.000234084582189098 0.008587392047047615 0.01872093975543976
+0.020703565329313278 6.707912689307705e-05 0.00018494627147447318 0.011492121033370495 0.010867039673030376
+0.019168170168995857 0.0001829677785281092 0.00021029575145803392 0.0074752322398126125 0.00709972158074379
+0.02126583456993103 0.00015208324475679547 0.00036396292853169143 0.008026250638067722 0.014360828325152397
+0.01718057319521904 0.00015107665967661887 0.0005164241301827133 0.011587423272430897 0.0035111531615257263
+0.02375122904777527 0.00031262155971489847 0.00013651455810759217 0.012583829462528229 0.01890718936920166
+0.01921224221587181 0.00018014949455391616 0.00038775074062868953 0.013170168735086918 0.01224981714040041
+0.027717262506484985 0.00015359792450908571 0.0003345748991705477 0.013639763928949833 0.03878657519817352
+0.016664762049913406 0.00045276578748598695 0.00018500258738640696 0.01876584254205227 0.008923625573515892
+0.018682442605495453 4.7313929826486856e-05 0.0005727144307456911 0.00944299902766943 0.009537206962704659
+0.019849494099617004 3.55412375938613e-05 0.0006530032842420042 0.010250319726765156 0.01084898505359888
+0.022628547623753548 0.0004508198762778193 0.00015244483074638993 0.006849002558737993 0.007830372080206871
+0.02347237803041935 0.0004933209274895489 0.0009464108152315021 0.009471542201936245 0.025351136922836304
+0.028293771669268608 0.0005531010101549327 8.930943295126781e-05 0.009230590425431728 0.027878347784280777
+0.031366124749183655 0.002046639798209071 0.0006804586737416685 0.005366188008338213 0.039285048842430115
+0.014804406091570854 7.196906517492607e-05 0.0005990713252685964 0.006938706152141094 0.009617566131055355
+0.01776377111673355 0.0010967933339998126 0.0019572549499571323 0.012051593512296677 0.007446489296853542
+0.02190333604812622 6.305272836470976e-05 0.0001604807621333748 0.010457524098455906 0.02667427249252796
+0.018698055297136307 0.0007597042713314295 0.0011829149443656206 0.008694990538060665 0.015179566107690334
+0.019856948405504227 5.37060150236357e-05 0.00035326360375620425 0.01564919389784336 0.004497429821640253
+0.022066300734877586 0.00048667285591363907 0.0003577627649065107 0.012128641828894615 0.009442642331123352
+0.022774292156100273 0.0001353188999928534 0.00022467676899395883 0.009429181925952435 0.018251143395900726
+0.02520618960261345 6.432625377783552e-05 0.0005010049208067358 0.011062873527407646 0.028439439833164215
+0.03202397748827934 6.046362250344828e-05 0.000284630514215678 0.011763068847358227 0.04103183373808861
+0.026356328278779984 3.9334372559096664e-05 0.00031159669742919505 0.007660105358809233 0.027119606733322144
+0.017735878005623817 0.00010494553134776652 7.839826866984367e-05 0.0062943375669419765 0.016722535714507103
+0.03029732219874859 0.0001324781187577173 0.00025403068866580725 0.005689414218068123 0.023568954318761826
+0.022799525409936905 0.0007545253611169755 0.00023747183149680495 0.009217952378094196 0.021211674436926842
+0.01499984785914421 0.00037608417915180326 0.00010024944640463218 0.008723119273781776 0.008224164135754108
+0.023266106843948364 0.000171211693668738 9.327733278041705e-05 0.007653145119547844 0.013177352026104927
+0.030513044446706772 0.0001816677104216069 0.0005051871412433684 0.007120159920305014 0.02382594719529152
+0.022093283012509346 0.00019365324988029897 0.0006869308999739587 0.008783143945038319 0.024800924584269524
+0.01791302114725113 0.00017220948939211667 0.00022243424609769136 0.012491858564317226 0.015715347602963448
+0.021983947604894638 0.0006735147326253355 0.0007953738677315414 0.01181875728070736 0.020999202504754066
+0.02194751240313053 2.612917342048604e-05 0.0002670029934961349 0.008756546303629875 0.027401603758335114
+0.021168895065784454 0.0009619793854653835 0.0001965832634596154 0.008506350219249725 0.021577514708042145
+0.025673028081655502 0.000424935802584514 9.167078678729013e-05 0.0083708381280303 0.02271827682852745
+0.024169843643903732 0.00021110499801579863 0.0008261293405666947 0.007825277745723724 0.01601029373705387
+0.0167558416724205 0.00020665372721850872 0.0004884477821178734 0.0071682860143482685 0.013929352164268494
+0.01715683937072754 0.00016751911607570946 0.0003163569199386984 0.006371776573359966 0.006006164941936731
+0.022893298417329788 0.0002512524661142379 0.00046771494089625776 0.007572154048830271 0.030961086973547935
+0.018932903185486794 5.7998913689516485e-05 0.00014586988254450262 0.00861283391714096 0.01805795170366764
+0.02315634861588478 0.0005737220635637641 0.00024091030354611576 0.008938639424741268 0.017642812803387642
+0.015061377547681332 0.00021684016974177212 0.00012075942504452541 0.00818005669862032 0.0017998232506215572
+0.023600442335009575 9.244338434655219e-05 0.0001866951206466183 0.00554571021348238 0.011028848588466644
+0.0315605066716671 0.0003766745503526181 0.00028606734122149646 0.005673840641975403 0.01846970058977604
+0.026681536808609962 5.770224379375577e-05 0.00028300657868385315 0.010228697210550308 0.03671666607260704
+0.024639710783958435 0.00030559339211322367 0.0007332083769142628 0.020954212173819542 0.027127698063850403
+0.02901916205883026 0.0004994029295630753 0.0014204789185896516 0.012342032045125961 0.04589639604091644
+0.024356912821531296 0.0002968245535157621 7.060253119561821e-05 0.0052680205553770065 0.030239101499319077
+0.029005711898207664 0.0016710617346689105 0.0007473659934476018 0.006560830865055323 0.029150839895009995
+0.018546216189861298 8.196656563086435e-05 9.32345719775185e-05 0.013628975488245487 0.00396389327943325
+0.026113297790288925 0.0008462417172268033 0.000458398659247905 0.007948818616569042 0.01826343685388565
+0.019019009545445442 0.0001985248236451298 0.0001789992384146899 0.006145879626274109 0.006860658526420593
+0.018933318555355072 8.251360850408673e-05 0.0012924165930598974 0.01228571031242609 0.00574828963726759
+0.012851575389504433 0.00011363952944520861 0.00016413710545748472 0.010319913737475872 0.002792573068290949
+0.02395758219063282 0.0004807574150618166 0.0005168633651919663 0.011149034835398197 0.03191721811890602
+0.03441972658038139 2.811423473758623e-05 0.00034344277810305357 0.019904831424355507 0.04938412085175514
+0.016422417014837265 0.00026545271975919604 0.00024664044030942023 0.011037500575184822 0.007332202512770891
+0.014700652100145817 0.00026318468735553324 0.0003317801747471094 0.007179738022387028 0.006995787378400564
+0.014079466462135315 0.00027864016010425985 0.0002030080941040069 0.007101192604750395 0.007050504442304373
+0.027469197288155556 0.00023458966461475939 0.00012070901721017435 0.02342148683965206 0.006450263783335686
+0.01650220900774002 5.656047142110765e-05 0.0001311279193032533 0.005274493712931871 0.00891146995127201
+0.02265104278922081 0.000376579788280651 0.00010531561565585434 0.004375911317765713 0.010709586553275585
+0.01718657836318016 0.00025460784672759473 0.00037592140142805874 0.007316399831324816 0.01704394444823265
+0.025977056473493576 0.0001371537655359134 0.0006517097936011851 0.012114537879824638 0.03258298337459564
+0.02231140248477459 0.00021527163335122168 0.000500905211083591 0.013284708373248577 0.03044079802930355
+0.027789108455181122 9.957859583664685e-05 0.0005587004707194865 0.012231922708451748 0.03604612499475479
+0.02301117777824402 0.00030325906118378043 0.00035206516622565687 0.010370769537985325 0.023280635476112366
+0.01983855850994587 0.00145303038880229 0.0001739178114803508 0.007309481035917997 0.018782250583171844
+0.024520136415958405 0.00037831359077244997 0.00010539652430452406 0.006911867763847113 0.022639567032456398
+0.025510169565677643 0.00033092338708229363 0.00040295234066434205 0.006371871568262577 0.008680731058120728
+0.021788235753774643 0.0004968038410879672 8.846174023346975e-05 0.0067246584221720695 0.004526397679001093
+0.01671559549868107 4.210391853121109e-05 0.0007018507458269596 0.010131102986633778 0.006387327332049608
+0.028272490948438644 5.3027870308142155e-05 0.000456640642369166 0.028952093794941902 0.026218052953481674
+0.021564405411481857 0.0003320280520711094 0.0004068409907631576 0.014425705187022686 0.015521546825766563
+0.020912615582346916 0.0003394444065634161 0.001043421565555036 0.016766488552093506 0.016476038843393326
+0.018406055867671967 0.00017081701662391424 0.0004602096159942448 0.010773967020213604 0.017027847468852997
+0.030024010688066483 0.0016661144327372313 0.00019251751655247062 0.008949805982410908 0.03764617070555687
+0.021648703143000603 0.0008777952753007412 6.59706856822595e-05 0.00732114352285862 0.02611217461526394
+0.015150908380746841 8.816471381578594e-05 0.00024348981969524175 0.006537816021591425 0.004993716720491648
+0.01840726099908352 0.00036170825478620827 0.00016791478265076876 0.004449908155947924 0.0051245796494185925
+0.020834624767303467 4.3247790017630905e-05 0.0012869478669017553 0.007639011368155479 0.018586207181215286
+0.021693997085094452 8.166180487023667e-05 0.0012318282388150692 0.008342599496245384 0.015236505307257175
+0.017937950789928436 0.0009304683771915734 0.00013947811385151 0.009420537389814854 0.0057013072073459625
+0.01820557378232479 0.0006345339352265 0.0004521118535194546 0.01335645467042923 0.009455287829041481
+0.021211089566349983 6.409629713743925e-05 0.00048561187577433884 0.016545148566365242 0.005895616021007299
+0.026115499436855316 0.0005245269276201725 0.00024940649745985866 0.011106887832283974 0.02185993455350399
+0.01598716527223587 0.00046791377826593816 0.0006172789726406336 0.008038285188376904 0.012102864682674408
+0.020908426493406296 0.0002682146441657096 0.0005570152425207198 0.010339788161218166 0.02242792211472988
+0.014817658811807632 0.00011233056284254417 0.0002272824349347502 0.0073271095752716064 0.007864447310566902
+0.01876649260520935 4.1304352635052055e-05 0.00013732732622884214 0.004756703972816467 0.009166963398456573
+0.015119152143597603 5.192334356252104e-05 0.00011881586397066712 0.006113219074904919 0.00452385051175952
+0.01992722973227501 3.744686546269804e-05 0.0007733040256425738 0.005071396008133888 0.016559738665819168
+0.015950890257954597 0.00014209798246156424 0.0007636622176505625 0.008671729825437069 0.013964102603495121
+0.022794801741838455 0.0007793700788170099 8.995099051389843e-05 0.011040735058486462 0.02012184076011181
+0.02225644886493683 0.0002807302516885102 0.00012450016220100224 0.007247860077768564 0.0264214389026165
+0.02086864411830902 0.0005042374832555652 0.0007780002197250724 0.01234525814652443 0.021930387243628502
+0.019515421241521835 0.00039795172051526606 0.0003673979954328388 0.014306853525340557 0.016419611871242523
+0.019084416329860687 0.00026460314984433353 0.0009677208145149052 0.010095808655023575 0.0059355744160711765
+0.026453938335180283 0.0002223733754362911 0.0011199391447007656 0.007150816731154919 0.029632069170475006
+0.020816318690776825 0.0006789990002289414 0.00030915666138753295 0.010998209938406944 0.016111044213175774
+0.0233210027217865 0.0004891901626251638 0.0002551160869188607 0.012864336371421814 0.01812804862856865
+0.018777653574943542 0.00044696242548525333 0.0010862333001568913 0.006421406753361225 0.01036028005182743
+0.017853910103440285 0.0003100864414591342 0.0007298901909962296 0.006695461925119162 0.00967865064740181
+0.02254655584692955 0.000782244373112917 0.0006932609830982983 0.0048684533685445786 0.02514704316854477
+0.01913820579648018 0.0005729247932322323 0.0010989002184942365 0.007929500192403793 0.008687656372785568
+0.020283732563257217 0.0006657634512521327 0.0002793965977616608 0.008399391546845436 0.015950290486216545
+0.01837366260588169 0.0007372836116701365 0.00026154296938329935 0.005528311710804701 0.015155085362493992
+0.022925782948732376 0.00021527674107346684 0.0012085308553650975 0.009055539034307003 0.02707459218800068
+0.02438047155737877 0.00018836505478248 0.0009331467445008457 0.008455446921288967 0.03137337416410446
+0.020727328956127167 0.0006390349590219557 0.0006125568761490285 0.006029187235981226 0.01637190207839012
+0.015781652182340622 0.0002979936543852091 0.0007339243311434984 0.008218906819820404 0.006601573899388313
+0.014523226767778397 0.0007148109725676477 0.0003287471190560609 0.010296987369656563 0.00048131236690096557
+0.0168455820530653 0.0008457937510684133 0.00031680348911322653 0.0074930982664227486 0.004976947791874409
+0.018994711339473724 9.82894780463539e-05 0.000505274860188365 0.008268509060144424 0.0029479102231562138
+0.02100432850420475 0.0003953548730351031 0.0007374876877292991 0.009097818285226822 0.007814147509634495
+0.019898273050785065 3.479624501778744e-05 0.0003033763205166906 0.014365959912538528 0.009564640931785107
+0.013040026649832726 7.567436841782182e-05 0.00011151022772537544 0.008105546236038208 0.0065795499831438065
+0.018948335200548172 2.8755090170307085e-05 8.58057028381154e-05 0.005748868454247713 0.015843190252780914
+0.020628057420253754 4.3603358790278435e-05 9.383238648297265e-05 0.006674619857221842 0.020555749535560608
+0.021146081387996674 2.7163794584339485e-05 0.00011005190026480705 0.007977182045578957 0.0245741605758667
+0.01986890286207199 3.902205571648665e-05 5.702810449292883e-05 0.00503057474270463 0.003358756424859166
+0.02028859406709671 4.671896749641746e-05 0.00024431449128314853 0.003897940507158637 0.005216821562498808
+0.020530812442302704 3.835366806015372e-05 0.00037851568777114153 0.005179520696401596 0.015398229472339153
+0.013273282907903194 0.00012509840598795563 0.00018494394316803664 0.005794797092676163 0.010828977450728416
+0.0202195905148983 0.00042136982665397227 0.0003020993899554014 0.00928608886897564 0.02214628830552101
+0.018003201112151146 0.00011491294571897015 0.0002704964717850089 0.008891160599887371 0.01073821634054184
+0.021179260686039925 0.00018975452985614538 0.00031869960366748273 0.007593594025820494 0.0039007908198982477
+0.01871384307742119 0.0004956436459906399 0.00012841397256124765 0.00874288845807314 0.01074394304305315
+0.012336168438196182 0.00016502922517247498 0.00019830501696560532 0.006237077992409468 0.000696180562954396
+0.01872853934764862 9.543100895825773e-05 0.000448721693828702 0.007911200635135174 0.014490358531475067
+0.016888514161109924 0.00015313620679080486 0.00021425278100650758 0.009614012204110622 0.014404316432774067
+0.018645066767930984 4.1958137444453314e-05 0.00011012357572326437 0.009587083011865616 0.020088717341423035
+0.013193601742386818 6.014022801537067e-05 0.00018598861061036587 0.0076713538728654385 0.007164370734244585
+0.02532515488564968 5.833465911564417e-05 0.0001497483317507431 0.006159450393170118 0.03172788396477699
+0.030706938356161118 6.016560655552894e-05 0.00013810618838761002 0.005659705027937889 0.024620894342660904
+0.01723693311214447 7.263651059474796e-05 0.00020649250654969364 0.006266494281589985 0.0060038804076612
+0.013000906445086002 8.163720485754311e-05 0.000201437680516392 0.007091288920491934 0.0006309004966169596
+0.02791774645447731 4.43259887106251e-05 0.0003667713899631053 0.007861613295972347 0.019693676382303238
+0.021951250731945038 9.956717258319259e-05 0.0003454366815276444 0.010142238810658455 0.019980864599347115
+0.017435062676668167 0.000340119528118521 0.00017595547251403332 0.01474214531481266 0.00850029569119215
+0.02030445821583271 0.00020562350982800126 0.00023723853519186378 0.011989705264568329 0.022982370108366013
+0.01436290517449379 0.00010238808317808434 0.0001011554995784536 0.013442561030387878 0.003850592067465186
+0.02289688214659691 0.00024999462766572833 0.00016998279897961766 0.009935912676155567 0.017276376485824585
+0.01866244524717331 0.00010402980842627585 0.0006371784256771207 0.00737965852022171 0.01081625185906887
+0.014935112558305264 0.00024534791009500623 0.00014751330309081823 0.006087442860007286 0.005856516305357218
+0.021272622048854828 0.0003377248940523714 0.0001594691857462749 0.003981613554060459 0.017362739890813828
+0.024048548191785812 4.4029617129126564e-05 0.00024354954075533897 0.004532797262072563 0.027217740193009377
+0.019390597939491272 0.0004495444009080529 0.0002895216748584062 0.0059032435528934 0.0160344410687685
+0.016442550346255302 0.0006968453526496887 0.00021628000831697136 0.006758685689419508 0.010559835471212864
+0.018574167042970657 0.00015149133105296642 0.00048761311336420476 0.02577385865151882 0.013325332663953304
+0.01851479336619377 0.00011984502634732053 0.0007015011506155133 0.008878006599843502 0.008869354613125324
+0.015858100727200508 6.992110866121948e-05 0.00025498351897113025 0.013906491920351982 0.006209452636539936
+0.015123213641345501 0.00033594417618587613 0.0010324819013476372 0.011248317547142506 0.009878246113657951
+0.026343878358602524 5.5688862630631775e-05 0.0007516068872064352 0.007452959660440683 0.022899389266967773
+0.0226452499628067 0.0010353567777201533 0.00011850139708258212 0.006025781389325857 0.014895796775817871
+0.01964987814426422 0.000660319346934557 0.00010810936510097235 0.006808919832110405 0.01741643436253071
+0.02040751278400421 9.138743916992098e-05 0.0005763683584518731 0.004733954556286335 0.018024371936917305
+0.01694854348897934 0.00023576797684654593 0.0002785225515253842 0.005250556860119104 0.011624153703451157
+0.02357829175889492 8.413024625042453e-05 0.00038518873043358326 0.00689672352746129 0.02620561420917511
+0.02833217941224575 6.532039697049186e-05 0.0006518585723824799 0.006972954608500004 0.041320495307445526
+0.01428439375013113 0.0003412258520256728 0.00011729427933460101 0.006438081152737141 0.011949682608246803
+0.016280602663755417 0.0005380717338994145 0.00014832367014605552 0.008677179925143719 0.01610841043293476
+0.018833529204130173 0.00015027885092422366 0.00031873458647169173 0.004232077393680811 0.005107346922159195
+0.01634420081973076 0.0006579741020686924 0.0006959220627322793 0.007694643456488848 0.0051806652918457985
+0.01898249238729477 3.135851875413209e-05 0.00015292504394892603 0.009009018540382385 0.023083116859197617
+0.025348413735628128 0.0005212412215769291 0.0009274429176002741 0.00809088908135891 0.025148581713438034
+0.01774117536842823 2.8233471311978064e-05 0.0002091613714583218 0.008019886910915375 0.017259005457162857
+0.0257401280105114 0.0005392259336076677 0.00038096372736617923 0.01081961952149868 0.03865405544638634
+0.019157299771904945 4.449713014764711e-05 0.00018635354354046285 0.011355581693351269 0.023769743740558624
+0.02392524480819702 0.0004961350932717323 0.0015559702878817916 0.007368654478341341 0.026751240715384483
+0.022840797901153564 6.227229459909722e-05 0.00018058541172649711 0.008211701177060604 0.012796442024409771
+0.02645164355635643 0.0017290699761360884 0.0015120900934562087 0.006640092469751835 0.012194642797112465
+0.023796478286385536 0.0001986527640838176 0.0008235774002969265 0.006315749604254961 0.016539789736270905
+0.02342185005545616 0.001204684260301292 9.950764069799334e-05 0.010803808458149433 0.020944256335496902
+0.016514994204044342 0.0015418378170579672 0.00013872513954993337 0.013780961744487286 0.011092700995504856
+0.026860876008868217 4.718397030956112e-05 0.0006910935626365244 0.01037634164094925 0.026809794828295708
+0.018712636083364487 0.00033310556318610907 0.0009851108770817518 0.010205304250121117 0.01078044157475233
+0.011806907132267952 3.5072491300525144e-05 0.00019214478379581124 0.014266244135797024 0.0015125873032957315
+0.01576998457312584 0.0004305553447920829 0.002254195511341095 0.008976609446108341 0.006320560351014137
+0.028016895055770874 0.00027644893270917237 0.0006032107048667967 0.00938422791659832 0.04244265705347061
+0.02267177775502205 0.0020373943261802197 0.00015722092939540744 0.008481508120894432 0.021726112812757492
+0.02082146145403385 0.000238466338487342 0.00023394331219606102 0.006748242769390345 0.019078979268670082
+0.018930155783891678 0.0003565162478480488 0.0008186248014681041 0.0066602798178792 0.009696315973997116
+0.024601098150014877 5.6359171139774844e-05 0.00014891233877278864 0.006272789090871811 0.018447501584887505
+0.023949049413204193 7.959105278132483e-05 0.0010597709333524108 0.007575783412903547 0.02695951610803604
+0.016281282529234886 0.00021849018230568618 0.0004355466226115823 0.00830281712114811 0.0017009286675602198
+0.01623556949198246 0.0006736788782291114 0.0003188670671079308 0.008532388135790825 0.009049620479345322
+0.01970032975077629 7.568440196337178e-05 0.00039056994137354195 0.015258305706083775 0.018736345693469048
+0.01668384112417698 0.0001752202515490353 0.0003721461689565331 0.013271342031657696 0.015233946032822132
+0.013399076648056507 4.903561421087943e-05 0.000598723127041012 0.0064807771705091 0.004164872225373983
+0.016153115779161453 6.464410398621112e-05 0.0005722844507545233 0.005602643359452486 0.004154455382376909
+0.02118060737848282 0.0008689945098012686 0.00019162350508850068 0.006369535345584154 0.01563209295272827
+0.02194274589419365 0.00040902357432059944 0.00015280254592653364 0.004719403572380543 0.013271719217300415
+0.020154204219579697 0.00037102095666341484 0.0003535869182087481 0.005326639395207167 0.005597691982984543
+0.022954091429710388 0.0004763970209751278 0.0001807961380109191 0.003993208985775709 0.02176838368177414
+0.02140957862138748 0.00017824287351686507 0.0014178394339978695 0.009413780644536018 0.016654707491397858
+0.02091137133538723 6.316656799754128e-05 0.0009034769027493894 0.01207212544977665 0.021959835663437843
+0.01677171140909195 0.0008339271880686283 0.0009642605436965823 0.010906881652772427 0.004334193654358387
+0.027423128485679626 5.8355693909106776e-05 0.0003867647028528154 0.008523468859493732 0.029755430296063423
+0.02559080347418785 0.0013479209737852216 0.0011500074760988355 0.016580119729042053 0.022979622706770897
+0.016325470060110092 0.00014534882211592048 0.00016727235924918205 0.011014346964657307 0.0055187842808663845
+0.023632515221834183 0.000321537401759997 0.0009617582545615733 0.00485884677618742 0.012127620168030262
+0.020616166293621063 5.44680115126539e-05 0.0001819707831600681 0.00921154860407114 0.012854264117777348
+0.015031426213681698 0.00014283672499004751 0.0005277455202303827 0.012764329090714455 0.0006495678098872304
+0.019310258328914642 0.0004884959780611098 0.0003853975795209408 0.006452319677919149 0.017227105796337128
+0.022431600838899612 0.00031300645787268877 0.0005023851990699768 0.006841110531240702 0.017750976607203484
+0.03235583007335663 0.00031690136529505253 0.00013674968795385212 0.010731704533100128 0.04570579156279564
+0.017600344493985176 0.0004783521580975503 0.00015127018559724092 0.010560671798884869 0.013434150256216526
+0.024418160319328308 2.8596356060006656e-05 0.0003400639398023486 0.009121648035943508 0.027649546042084694
+0.014864293858408928 0.0001127984796767123 0.0002035704965237528 0.007409755140542984 0.010379830375313759
+0.019033612683415413 0.00018149444076698273 0.0012271812884137034 0.007040777709335089 0.0066715688444674015
+0.014637594111263752 0.0006367055466398597 9.610933921067044e-05 0.006849967408925295 0.0043680439703166485
+0.019759787246584892 0.0010381536558270454 8.65185575094074e-05 0.004882087465375662 0.009418917819857597
+0.014403259381651878 6.994893192313612e-05 0.0008789921994321048 0.010097990743815899 0.00996581930667162
+0.01776304841041565 6.65335392113775e-05 0.0002669687964953482 0.017166057601571083 0.018852289766073227
+0.017862943932414055 0.0005193055258132517 0.0018269283464178443 0.010608820244669914 0.012007900513708591
+0.02030562236905098 3.255926276324317e-05 0.0014014181215316057 0.006368068512529135 0.017088914290070534
+0.02153845503926277 0.0025180720258504152 0.00032915567862801254 0.009295124560594559 0.01975664310157299
+0.018260443583130836 0.0012803111458197236 6.856008258182555e-05 0.006308695767074823 0.008778170682489872
+0.016471216455101967 0.00020833921735174954 0.0006020231521688402 0.005910354666411877 0.004056416451931
+0.018115434795618057 0.0005796084878966212 0.00013695920642931014 0.005015104543417692 0.008352220989763737
+0.023949237540364265 8.989313209895045e-05 0.0008130256901495159 0.009451731108129025 0.02681148797273636
+0.022710638120770454 0.00014891440514475107 0.0006029590149410069 0.009449299424886703 0.031934235244989395
+0.019504817202687263 0.0007493326556868851 0.00025690512848086655 0.008020354434847832 0.009984190575778484
+0.025656169280409813 0.00011579834972508252 0.00029975693905726075 0.008893845602869987 0.037017107009887695
+0.0186146292835474 0.0011639680014923215 0.0003471717063803226 0.009365884587168694 0.01967066153883934
+0.025299014523625374 0.0010434312280267477 0.00023055441852193326 0.008443407714366913 0.02476697973906994
+0.024440579116344452 5.1704704674193636e-05 0.00028204204863868654 0.0073076593689620495 0.013589300215244293
+0.02259906567633152 0.00029749423265457153 0.00021619726612698287 0.007529955357313156 0.013228482566773891
+0.019568905234336853 6.41741935396567e-05 0.0006400173297151923 0.009911407716572285 0.005159641150385141
+0.020581506192684174 4.9127334932563826e-05 0.0005989266210235655 0.01565423235297203 0.0177310761064291
+0.016042035073041916 0.0008604543982073665 0.0003114196297246963 0.012730006128549576 0.012988045811653137
+0.024624720215797424 0.00042386967106722295 0.0006244766991585493 0.01150580681860447 0.038125958293676376
+0.023573163896799088 0.00038181396666914225 8.555631939088926e-05 0.012423604726791382 0.03219347447156906
+0.030082382261753082 0.0009434358798898757 0.00012068689102306962 0.013596763834357262 0.03398239612579346
+0.013264216482639313 0.00011440263187978417 0.0006342987762764096 0.009564965032041073 0.0031363440211862326
+0.023938599973917007 0.0001135302591137588 0.00015120662283152342 0.007965338416397572 0.0156411025673151
+0.026681317016482353 0.0005087677272967994 0.0010744822211563587 0.005601105280220509 0.007089509163051844
+0.01713707484304905 0.00010766877676360309 0.0009125664946623147 0.007552013266831636 0.00952987838536501
+0.017830301076173782 0.0023820982314646244 6.860146822873503e-05 0.010072270408272743 0.010323402471840382
+0.020871423184871674 0.0017731758998706937 0.00011254780110903084 0.012539601884782314 0.024037720635533333
+0.022019971162080765 4.0908031223807484e-05 0.0002521515707485378 0.011483744718134403 0.02515658177435398
+0.018052902072668076 0.0005255405558273196 0.00011521509441081434 0.009797291830182076 0.01932237669825554
+0.01807531714439392 0.000468298327177763 0.0006262414972297847 0.009555742144584656 0.016033241525292397
+0.02185462787747383 0.0005746427341364324 0.0005834093899466097 0.010962252505123615 0.01772145926952362
+0.020921753719449043 0.0004979136865586042 0.00010895056766457856 0.00956189725548029 0.015378139913082123
+0.017582911998033524 7.259591802721843e-05 4.527247074292973e-05 0.0046568927355110645 0.0058198971673846245
+0.020868375897407532 0.0001324280456174165 0.00020173958910163492 0.0040132226422429085 0.013570698909461498
+0.023286905139684677 0.00038940785452723503 0.0007492417353205383 0.008767531253397465 0.02753196842968464
+0.025423109531402588 0.0009787880117073655 0.0006036506383679807 0.015164725482463837 0.025184666737914085
+0.02509145252406597 0.001009765430353582 0.0004456464375834912 0.013320190832018852 0.026067180559039116
+0.014931262470781803 3.2266179914586246e-05 0.00022185516718309373 0.00998213142156601 0.00894459243863821
+0.022550389170646667 0.0016222851118072867 0.0003067247453145683 0.015285724774003029 0.017542075365781784
+0.029397346079349518 0.001097974949516356 0.00016617128858342767 0.011261261999607086 0.03692900016903877
+0.017554335296154022 2.2831240130471997e-05 0.0006361507112160325 0.007818245328962803 0.0002611537929624319
+0.016115732491016388 9.939647134160623e-05 0.0002511190250515938 0.008224757388234138 0.0037479132879525423
+0.027449242770671844 2.868025694624521e-05 0.0001515268231742084 0.007937523536384106 0.030796218663454056
+0.022598784416913986 8.569290366722271e-05 0.0001768158545019105 0.009563089348375797 0.025153033435344696
+0.026840949431061745 0.0006389113841578364 0.00019900090410374105 0.005585810635238886 0.0282625462859869
+0.028942877426743507 0.0001522697857581079 0.00018154940335080028 0.006703195162117481 0.04778944328427315
+0.02096031978726387 0.0002576791448518634 6.442085577873513e-05 0.005932371132075787 0.0214516744017601
+0.018725957721471786 0.0004929124843329191 5.591319495579228e-05 0.009116177447140217 0.009350741282105446
+0.021616365760564804 0.0003496767021715641 5.658032387145795e-05 0.010295065119862556 0.006993246264755726
+0.019582059234380722 0.00031177644268609583 0.0003465144836809486 0.005906997248530388 0.002622221829369664
+0.02242196910083294 5.5437754781451076e-05 0.0005512997158803046 0.00595185486599803 0.002914777956902981
+0.02290630340576172 0.0011382077354937792 0.00011477770749479532 0.011735806241631508 0.029360070824623108
+0.023627426475286484 0.0015637631295248866 0.0001933087332872674 0.0105845732614398 0.02499154768884182
+0.025723226368427277 0.00014002167154103518 0.00011651881504803896 0.013369398191571236 0.03854363411664963
+0.02587929740548134 0.0002480624825693667 0.00011973494110861793 0.012955897487699986 0.03278842940926552
+0.024845674633979797 0.00023004993272479624 0.0006132922717370093 0.008076912723481655 0.02329115942120552
+0.016546525061130524 0.00046045874478295445 0.0005370075232349336 0.008234829641878605 0.006645058281719685
+0.023976586759090424 0.0009125845972448587 4.5537814003182575e-05 0.005420949310064316 0.018017854541540146
+0.020771970972418785 0.00018761481624096632 7.48766033211723e-05 0.00604346813634038 0.0044649154879152775
+0.022772155702114105 0.00041014672024175525 9.874229726847261e-05 0.005717877298593521 0.01839372329413891
+0.01858215406537056 0.0005201577441766858 0.0002461689873598516 0.013308973982930183 0.016863396391272545
+0.023960251361131668 0.00031484547071158886 0.00048761017387732863 0.021008498966693878 0.01523043867200613
+0.027976833283901215 0.0003894653345923871 0.00011003524559782818 0.013388758525252342 0.038766320794820786
+0.022038539871573448 0.0001340815215371549 0.00015701379743404686 0.014402204193174839 0.02424498274922371
+0.01985759101808071 0.0006308745942078531 0.00023636920377612114 0.015557411126792431 0.010604617185890675
+0.027198072522878647 0.0001934913598233834 0.00010218218085356057 0.01359142828732729 0.011971775442361832
+0.019754573702812195 8.931318006943911e-05 0.0007311358349397779 0.010831108316779137 0.00618788180872798
+0.022633623331785202 0.00017608908819966018 5.171275552129373e-05 0.04180515557527542 0.010927392169833183
+0.022361498326063156 0.0003797566459979862 0.0005260378820821643 0.012352903373539448 0.010591921396553516
+0.019612658768892288 0.000339314341545105 8.424617408309132e-05 0.00853536743670702 0.010669677518308163
+0.025429535657167435 0.0007770519005134702 0.00028182423557154834 0.009992117993533611 0.03731914609670639
+0.024253545328974724 6.280469824559987e-05 0.00034306911402381957 0.015902608633041382 0.030524423345923424
+0.02160259708762169 0.00015358311065938324 0.00023485562996938825 0.008746073581278324 0.02480989694595337
+0.021634835749864578 7.519384234910831e-05 9.450861398363486e-05 0.006732506211847067 0.023941269144415855
+0.01742197386920452 0.00012129398965043947 8.953330689109862e-05 0.007030731067061424 9.883784514386207e-05
+0.019574353471398354 0.00010292114893672988 0.0001346335920970887 0.007661894429475069 0.01267450675368309
+0.02839442901313305 4.942261512042023e-05 9.804784349398687e-05 0.005162409041076899 0.011047817766666412
+0.017884641885757446 4.175330468569882e-05 0.0002972691727336496 0.0077454405836761 0.003200882114470005
+0.020249132066965103 6.512505933642387e-05 0.00035852266591973603 0.013501271605491638 0.01714063249528408
+0.017854025587439537 0.0005179011495783925 0.0007452791905961931 0.01262133289128542 0.016598548740148544
+0.025656672194600105 0.0002644068154040724 0.0010095364414155483 0.01721651665866375 0.039458319544792175
+0.02835143730044365 0.00028423042385838926 0.0005058056558482349 0.010376721620559692 0.036039624363183975
+0.01821482926607132 0.000804328708909452 0.0001548315049149096 0.01021239347755909 0.012579197995364666
+0.01998545043170452 0.00021839900000486523 0.00026180752320215106 0.008023614063858986 0.010328719392418861
+0.018697872757911682 3.2154588552657515e-05 0.00025760935386642814 0.006486093625426292 0.008670050650835037
+0.018597666174173355 3.663307870738208e-05 0.00038986970321275294 0.006266332231462002 0.017679225653409958
+0.018205828964710236 7.571589958388358e-05 0.0007233681972138584 0.010302710346877575 0.0021260573994368315
+0.019070420414209366 0.0008528729667887092 0.00018045080651063472 0.00815874058753252 0.014437026344239712
+0.01604440249502659 0.0004715987597592175 0.0002144814352504909 0.008425388485193253 0.011455761268734932
+0.021259868517518044 3.08915987261571e-05 0.0003169361734762788 0.011475118808448315 0.011661351658403873
+0.018870964646339417 7.262043072842062e-05 0.00036506939795799553 0.008987950161099434 0.014662418514490128
+0.01295578945428133 7.481945067411289e-05 0.0005114517989568412 0.009160484187304974 0.005402986891567707
+0.014751644805073738 0.00028729930636473 0.0002903894055634737 0.008068681694567204 0.009913597255945206
+0.018191512674093246 0.00035849507548846304 0.00011865761916851625 0.01101850625127554 0.013131448067724705
+0.01956687495112419 2.7533051252248697e-05 0.00021385334548540413 0.007993092760443687 0.0063122231513261795
+0.01667097397148609 0.00011134350643260404 0.00014891373575665057 0.00600818358361721 0.009433449245989323
+0.028054185211658478 3.18963429890573e-05 0.0004123375692870468 0.007503589615225792 0.03803795948624611
+0.01776888594031334 4.8786878323880956e-05 0.00042360974475741386 0.014370737597346306 0.01895778626203537
+0.0186908021569252 0.0007557297358289361 0.0002572642406448722 0.008727931417524815 0.018414510414004326
+0.016444403678178787 0.0001390453107887879 0.00017249326629098505 0.008001547306776047 0.0015157784800976515
+0.016699939966201782 0.0007506711990572512 0.00047046245890669525 0.0086275739595294 0.00756961852312088
+0.011213716119527817 0.0005470180185511708 0.0002692253328859806 0.008717091754078865 0.002926892600953579
+0.013701016083359718 0.00016907807730603963 0.0003986309457104653 0.010367196053266525 0.007357001304626465
+0.016061382368206978 0.00029063771944493055 0.00043688149889931083 0.007752292789518833 0.015070908702909946
+0.019168853759765625 0.000155515837832354 0.00018146782531403005 0.007278610486537218 0.014985247515141964
+0.016813665628433228 0.00023188212071545422 0.0003730167809408158 0.008378471247851849 0.015763282775878906
+0.015956446528434753 0.000304184592096135 0.0002994728565681726 0.007564975414425135 0.013376440852880478
+0.01926935464143753 0.0002209380327258259 0.0002117476105922833 0.006242468487471342 0.002967341337352991
+0.022468339651823044 0.00025799524155445397 0.00026515632634982467 0.0038311101961880922 0.013354238122701645
+0.014459165744483471 0.00044373978744260967 0.0005263160564936697 0.004355818033218384 0.0023043244145810604
+0.023376883938908577 8.929053001338616e-05 5.507696914719418e-05 0.008045054040849209 0.009280961938202381
+0.016099678352475166 0.0005767700495198369 0.00034131528809666634 0.011882655322551727 0.00908801332116127
+0.017838623374700546 9.343020792584866e-05 7.681216084165499e-05 0.009305394254624844 0.019117645919322968
+0.022386007010936737 0.000201819246285595 0.00017566006863489747 0.007691275328397751 0.03500950708985329
+0.016311034560203552 4.959764555678703e-05 0.00012262431846465915 0.005763302091509104 0.014318214729428291
+0.014791007153689861 9.516957652522251e-05 0.0006505436613224447 0.005230739247053862 0.004482558462768793
+0.02321217767894268 0.0001112192403525114 0.0001291567605221644 0.0035775029100477695 0.020212609320878983
+0.016299579292535782 0.0005528065375983715 0.00013994943583384156 0.004524510353803635 0.007145632524043322
+0.018111128360033035 6.897454295540228e-05 4.427013482199982e-05 0.004183339886367321 0.018531158566474915
+0.017318567261099815 0.00016405130736529827 7.949458813527599e-05 0.004876594990491867 0.005803946405649185
+0.018540099263191223 0.0001812848204281181 0.00011817468475783244 0.009354441426694393 0.013212918303906918
+0.015526363626122475 6.876301631564274e-05 0.00028782078879885375 0.00894712470471859 0.010770066641271114
+0.018213428556919098 8.369032730115578e-05 0.0002112948859576136 0.01196613535284996 0.012819310650229454
+0.01759304851293564 0.00010399332677479833 9.920586307998747e-05 0.008364849723875523 0.01882246695458889
+0.01938994973897934 3.735591235454194e-05 0.00011598980199778453 0.0062201255932450294 0.020885633304715157
+0.012514305301010609 0.0002328278060304001 6.1277904023882e-05 0.005871783476322889 0.008107408881187439
+0.015071163885295391 0.00021054534590803087 5.1194823754485697e-05 0.005010250490158796 0.010318241082131863
+0.01602860540151596 4.298926069168374e-05 7.710252975812182e-05 0.004729713778942823 0.005112841725349426
+0.025991907343268394 4.483551310840994e-05 7.835400174371898e-05 0.003099061781540513 0.019111309200525284
+0.017942046746611595 7.327338244067505e-05 0.00043795062811113894 0.0036903424188494682 0.017084240913391113
+0.01902187056839466 3.810786074609496e-05 0.0006164806545712054 0.006823322735726833 0.019741181284189224
+0.017468217760324478 0.0007909092237241566 0.00017265407950617373 0.009282086975872517 0.0164791326969862
+0.01210701558738947 0.0007333356188610196 0.00033348743454553187 0.011103943921625614 0.0031564347445964813
+0.01531150285154581 0.0001985790440812707 0.0002868089359253645 0.009533762000501156 0.011064643040299416
+0.013479022309184074 0.0009217981132678688 0.00035733505501411855 0.010355609469115734 0.0035097470972687006
+0.021273525431752205 3.524907515384257e-05 0.0004375818243715912 0.0110134556889534 0.03018011339008808
+0.014396720565855503 0.0004069585702382028 0.001678862376138568 0.009951794520020485 0.007310394663363695
+0.017230618745088577 5.949243131908588e-05 0.00022791857190895826 0.005512130446732044 0.006694598123431206
+0.02008955366909504 0.0012152532581239939 0.0006392785580828786 0.007104386575520039 0.007217179052531719
+0.017432808876037598 7.774781988700852e-05 0.00013517522893380374 0.00432212371379137 0.010133280418813229
+0.01987875998020172 0.0009888972854241729 0.00044158968375995755 0.008672405034303665 0.014694245532155037
+0.02296094037592411 0.0004185835423413664 0.00010333168029319495 0.0065768687054514885 0.03396175429224968
+0.01198384165763855 0.00028757061227224767 0.000923499115742743 0.004238791298121214 0.007810895796865225
+0.023141706362366676 0.00021087181812617928 0.0005840777303092182 0.006109829992055893 0.02829851023852825
+0.014461744576692581 0.00019122073717880994 0.0001176300720544532 0.010672232136130333 0.0051530203782022
+0.01996220089495182 0.00032577753881923854 0.00029248601640574634 0.011989247053861618 0.020846566185355186
+0.020023208111524582 8.087568858172745e-05 0.00016469070396851748 0.011355317197740078 0.015182990580797195
+0.021332615986466408 0.00034866182249970734 0.00020058998779859394 0.006278062239289284 0.0014554247027263045
+0.019013116136193275 6.215445318957791e-05 0.00019675523799378425 0.0069814096204936504 0.0009432195802219212
+0.009622449986636639 0.00017684356134850532 0.00032159272814169526 0.009048931300640106 0.0004138777148909867
+0.015223830007016659 0.0001542707032058388 0.00031864625634625554 0.008665923960506916 0.01569892093539238
+0.025627978146076202 4.157217335887253e-05 0.00013055287126917392 0.007721122354269028 0.02473493292927742
+0.014954215846955776 0.00014443208056036383 0.0001338858128292486 0.006878928747028112 0.013627168722450733
+0.018358491361141205 7.101115625118837e-05 0.00014519382966682315 0.005656978581100702 0.012661296874284744
+0.017193295061588287 2.7151658287039027e-05 7.234721852000803e-05 0.007509899791330099 0.018626410514116287
+0.011980531737208366 3.938171357731335e-05 6.098033190937713e-05 0.007407346740365028 0.003747716546058655
+0.014941140078008175 5.562648948398419e-05 0.00023098464589565992 0.00283034541644156 0.004759728442877531
+0.022070199251174927 3.160679989377968e-05 0.0003864225873257965 0.0034737137611955404 0.010128241963684559
+0.016639403998851776 0.0002844943664968014 7.711134821875021e-05 0.005249613896012306 0.010589930228888988
+0.01903412863612175 0.00030714491731487215 0.0001821738842409104 0.005955686792731285 0.022314537316560745
+0.015107075683772564 7.387696678051725e-05 0.0002695659059099853 0.01006343960762024 0.009872562251985073
+0.014823942445218563 0.0001322664029430598 0.00014164460299070925 0.009447424672544003 0.01329185999929905
+0.01609719730913639 3.596706301323138e-05 0.0005346113466657698 0.008253621868789196 0.009191320277750492
+0.023207662627100945 5.122080983710475e-05 0.00026074444758705795 0.010605625808238983 0.02337772585451603
+0.01989738643169403 0.0002610500087030232 0.00015578692546114326 0.010932087898254395 0.012683775275945663
+0.01632677949965 6.978621968301013e-05 7.249036571010947e-05 0.006688467226922512 0.010018236935138702
+0.0177131537348032 9.353440691484138e-05 0.00010844838107004762 0.006131697446107864 0.014958014711737633
+0.01763567514717579 3.789708352996968e-05 0.00011620131408562884 0.00624003354460001 0.009142563678324223
+0.01609598472714424 3.3579326554900035e-05 0.00021008870680816472 0.0067595867440104485 0.009082511998713017
+0.014803105965256691 4.859953696723096e-05 0.0001580791868036613 0.013394691050052643 0.004333061631768942
+0.018232103437185287 8.241672912845388e-05 7.610133616253734e-05 0.00470638507977128 0.015307070687413216
+0.021666981279850006 3.091124017373659e-05 0.0001332256360910833 0.011305146850645542 0.0241651963442564
+0.021298352628946304 5.896230140933767e-05 0.00011718304449459538 0.008448985405266285 0.026346612721681595
+0.013839242979884148 2.924156069639139e-05 0.00021791366452816874 0.005790945142507553 0.007604369428008795
+0.01625918224453926 0.00023574016813654453 7.818166341166943e-05 0.0051976884715259075 0.0075913588516414165
+0.018007472157478333 0.00034937000600621104 0.00015711279411334544 0.005276467651128769 0.02186875231564045
+0.022723674774169922 0.0001020936033455655 0.0001784760970622301 0.006899967323988676 0.023450637236237526
+0.02154172770678997 0.0004214656655676663 0.0004539633519016206 0.0044109816662967205 0.010106311179697514
+0.021712519228458405 8.141563012031838e-05 0.00011898916272912174 0.005396883934736252 0.013971582986414433
+0.019442643970251083 0.0005805640830658376 0.0010850761318579316 0.008569932542741299 0.01595122739672661
+0.01651756279170513 0.00027920264983549714 0.00017792174185160547 0.006897398736327887 0.011975672096014023
+0.01926581561565399 0.0008837752975523472 0.0008668951340951025 0.00625438429415226 0.008489868603646755
+0.016918916255235672 0.0002722641220316291 0.00038076716009527445 0.01320485956966877 0.008580798283219337
+0.011953463777899742 0.0006002297159284353 0.0006781611009500921 0.013504059053957462 0.0038119202945381403
+0.018735088407993317 0.0002995611575897783 0.0008039689855650067 0.0058204433880746365 0.022320706397294998
+0.016627706587314606 0.0001617670350242406 0.0005379404174163938 0.007307151798158884 0.012825087644159794
+0.012710939161479473 0.0003574928268790245 0.00019257722306065261 0.008500383235514164 0.003309956518933177
+0.015410108491778374 0.0001495964970672503 8.882809925125912e-05 0.0034259746316820383 0.010733788833022118
+0.012478439137339592 2.613107062643394e-05 0.0001304334873566404 0.004747496917843819 0.0008680788450874388
+0.012468540109694004 3.0357616196852177e-05 0.000250755634624511 0.005417489446699619 0.002207597019150853
+0.012249385006725788 6.896112608956173e-05 0.00022870433167554438 0.006649654358625412 0.007717210333794355
+0.015248223207890987 0.00026072614127770066 0.00019891293777618557 0.007477042265236378 0.0036678414326161146
+0.02186867967247963 4.9366110033588484e-05 0.00019055409939028323 0.008019112050533295 0.021896153688430786
+0.01759207807481289 0.0001095136467483826 0.0003576045564841479 0.00998432096093893 0.0069382996298372746
+0.016393154859542847 3.05712383124046e-05 7.874049333622679e-05 0.010242694057524204 0.017129695042967796
+0.015594814904034138 0.00012599404726643115 0.0005068824393674731 0.006860822904855013 0.013843772932887077
+0.01474747247993946 9.861370926955715e-05 0.00012843715376220644 0.00586755620315671 0.011381375603377819
+0.012608909048140049 0.0004927298869006336 0.00011560319398995489 0.006284111645072699 0.00669506099075079
+0.01844920963048935 2.1834237486473285e-05 0.00018519029254093766 0.004404556471854448 0.008254550397396088
+0.016911379992961884 0.0001010242267511785 0.00017103186110034585 0.0044893622398376465 0.008857588283717632
+0.015573394484817982 3.569707041606307e-05 0.00010901656787609681 0.005511123687028885 0.010763760656118393
+0.013995341956615448 5.116756437928416e-05 0.00022300145064946264 0.007287611719220877 0.007309088949114084
+0.01082745660096407 0.00015183490177150816 0.0001053260057233274 0.006629098206758499 0.0025450955145061016
+0.01666298694908619 0.00011309317778795958 0.00010177143849432468 0.0073385899886488914 0.0006052266689948738
+0.02005896344780922 5.424210030469112e-05 0.00011932855704799294 0.006075525656342506 0.024458488449454308
+0.012457748875021935 9.211272117681801e-05 0.00014591355284210294 0.005166249815374613 0.0007733535021543503
+0.015483636409044266 4.062926018377766e-05 0.00011811644799308851 0.005685803480446339 0.008466921746730804
+0.018798377364873886 5.378886635298841e-05 0.00014454859774559736 0.00655730627477169 0.02036304771900177
+0.019680436700582504 4.5290034904610366e-05 0.00022722459107171744 0.007946746423840523 0.0024233912117779255
+0.025969555601477623 4.273547892807983e-05 0.0001890273269964382 0.0064910054206848145 0.025067798793315887
+0.01655467599630356 0.0003221522201783955 0.00017162678705062717 0.006645120680332184 0.020509667694568634
+0.012249505147337914 0.00010587046563159674 0.00021547311916947365 0.009208697825670242 0.004455416928976774
+0.017696604132652283 0.000433694280218333 0.00034349533962085843 0.010824386030435562 0.006948671769350767
+0.015282537788152695 0.00012832377979066223 7.51672123442404e-05 0.005875756032764912 0.003327947808429599
+0.018938563764095306 0.000641991849988699 0.0016583282267674804 0.00751475477591157 0.01053638570010662
+0.01944350264966488 4.238060864736326e-05 0.0005556781543418765 0.013627114705741405 0.023264151066541672
+0.010632883757352829 0.0020604440942406654 0.001367925782687962 0.006522743031382561 0.0004170583561062813
+0.02233830839395523 0.00021931239461991936 0.0004465327365323901 0.0029446182306855917 0.02014836110174656
+0.012800456024706364 0.002667793305590749 0.001615012064576149 0.005957028362900019 0.0013703436125069857
+0.022359095513820648 0.0005290542030707002 0.00011102972348453477 0.006756199058145285 0.015408393926918507
+0.017201239243149757 0.0015177193563431501 0.0029036435298621655 0.005959557369351387 0.010365751571953297
+0.025896485894918442 0.00018657567852642387 0.0006066171918064356 0.009474312886595726 0.00918080098927021
+0.015061190351843834 0.0015174273867160082 0.0011499690590426326 0.007452100515365601 0.009248701855540276
+0.014912809245288372 0.00047885856474749744 0.0004725187609437853 0.00442067626863718 0.006714401766657829
+0.017404576763510704 0.000753737345803529 0.00028238334925845265 0.007024487946182489 0.018117161467671394
+0.020256567746400833 0.0010445392690598965 0.00031843912438489497 0.008799910545349121 0.018669061362743378
+0.019087614491581917 6.018158092047088e-05 0.00024770505842752755 0.005594682414084673 0.021596716716885567
+0.020167451351881027 0.0008222248870879412 0.0010834417771548033 0.0043085189536213875 0.00467606820166111
+0.01911097951233387 3.5692562960321084e-05 0.00013795540144201368 0.005073487292975187 0.001205718726851046
+0.015088560059666634 0.0006930938106961548 0.0005667862133122981 0.005989000201225281 0.005336420144885778
+0.018104691058397293 0.00011800281936302781 0.00016341453010682017 0.006420508958399296 0.021491674706339836
+0.018334349617362022 0.000370227440726012 0.00021644534717779607 0.008719292469322681 0.02270558476448059
+0.014149317517876625 0.0001253156951861456 8.223004260798916e-05 0.007080836221575737 0.011815214529633522
+0.013820036314427853 0.00011113709479104728 0.0004911167197860777 0.006352521013468504 0.0024525488261133432
+0.018153224140405655 3.658523928606883e-05 0.00015322044782806188 0.005736102350056171 0.020814746618270874
+0.023057419806718826 0.00033566914498806 0.0003389575576875359 0.008865643292665482 0.019431138411164284
+0.010667871683835983 3.182714499416761e-05 6.764059799024835e-05 0.005603851284831762 0.001049803919158876
+0.014227855950593948 0.0002902266278397292 0.00018577372247818857 0.003284495323896408 0.00452487962320447
+0.020021263509988785 3.239668149035424e-05 0.00025299409753642976 0.007788355462253094 0.01671631447970867
+0.014461811631917953 3.109395402134396e-05 0.00031798845157027245 0.010894618928432465 0.0003984046634286642
+0.024194005876779556 0.00028233841294422746 0.00029771827394142747 0.0064479876309633255 0.03448544070124626
+0.014323882758617401 4.317983257351443e-05 6.993277202127501e-05 0.0071139587089419365 0.010962370783090591
+0.022585859522223473 0.0004715915711130947 0.0001937665801960975 0.0113693717867136 0.01778150163590908
+0.016188498586416245 2.3178745323093608e-05 0.00023427042469847947 0.007691795006394386 0.008022349327802658
+0.01714504510164261 8.064621943049133e-05 0.000368433422408998 0.006472112610936165 0.016542771831154823
+0.02036122977733612 0.0002742835786193609 0.00028971279971301556 0.008057910017669201 0.014681005850434303
+0.02464810013771057 7.552083843620494e-05 0.00033314002212136984 0.006862280424684286 0.0328560471534729
+0.0185831431299448 0.000431911030318588 0.0001394586724927649 0.0061820936389267445 0.01843726448714733
+0.01708040200173855 0.00021769682643935084 7.180403190432116e-05 0.009523533284664154 0.010552519001066685
+0.017710836604237556 0.00023988119210116565 0.0006377590470947325 0.006714582443237305 0.005738288164138794
+0.02241087332367897 8.655419514980167e-05 0.00020064663840457797 0.01298956386744976 0.007472369819879532
+0.014907769858837128 0.00021675370226148516 0.000272893812507391 0.007709584664553404 0.010303172282874584
+0.016474870964884758 4.3575513700488955e-05 0.0001521579542895779 0.008774563670158386 0.007341951131820679
+0.014094682410359383 0.0001278429408557713 0.0001295128313358873 0.011194455437362194 0.0021007771138101816
+0.014130533672869205 3.6355726479087025e-05 0.0002465063298586756 0.00609172647818923 0.011655847541987896
+0.022600673139095306 5.488095121108927e-05 0.0005237676086835563 0.0061714984476566315 0.03166600689291954
+0.015043665654957294 0.00018314715998712927 9.235152538167313e-05 0.00796838290989399 0.008351011201739311
+0.02093394100666046 0.00011819963401649147 5.068274185759947e-05 0.005068762693554163 0.020244058221578598
+0.02114677056670189 6.416965334210545e-05 0.00012324081035330892 0.006948379799723625 0.0193509291857481
+0.017222443595528603 6.675499025732279e-05 0.0005640379386022687 0.010351558215916157 0.0028874236159026623
+0.018320832401514053 0.0003183907538186759 0.00018075994739774615 0.006795060820877552 0.015341510996222496
+0.01516227051615715 0.0005181059241294861 7.235260272864252e-05 0.0069830818101763725 0.010576964356005192
+0.013608084060251713 4.791816900251433e-05 0.0005266234511509538 0.008458179421722889 0.00602355832234025
+0.021236982196569443 5.1624974730657414e-05 0.0002517217362765223 0.012393857352435589 0.02275082841515541
+0.02207486517727375 0.00010485521488590166 0.0007933989400044084 0.005313766188919544 0.018697621300816536
+0.013251024298369884 8.992713264888152e-05 0.0005772725562565029 0.006921487860381603 0.007947146892547607
+0.014872364699840546 0.0006227322155609727 0.00014620438741985708 0.005305248312652111 0.009340443648397923
+0.016146190464496613 0.00017982245481107384 0.0002966542378999293 0.01085782889276743 0.016296612098813057
+0.011184128001332283 2.8234851924935356e-05 0.0002779497590381652 0.006786458659917116 0.007337002083659172
+0.017574533820152283 4.262900984031148e-05 0.0001633540668990463 0.004807265475392342 0.016047492623329163
+0.013714568689465523 7.972688035806641e-05 0.000383790087653324 0.004205755889415741 0.004037333652377129
+0.020091228187084198 0.0004293603706173599 0.00011642670870060101 0.003488150890916586 0.021581294015049934
+0.015886861830949783 0.00019191906903870404 0.00023086844885256141 0.0046821460127830505 0.007315218448638916
+0.013875089585781097 6.248133286135271e-05 0.0005657925503328443 0.00637053232640028 0.00932911317795515
+0.01672450453042984 5.5355521908495575e-05 0.00025919193285517395 0.00962015800178051 0.006522179115563631
+0.017274318262934685 0.00013675248192157596 0.0004982991958968341 0.009308080188930035 0.0039014448411762714
+0.013614954426884651 6.448687054216862e-05 0.0002674334100447595 0.009655522182583809 0.0024269986897706985
+0.019406991079449654 0.0006400015554390848 0.00017498002853244543 0.011062601581215858 0.022697368636727333
+0.021182723343372345 3.2616197131574154e-05 0.0002154904359485954 0.01000475324690342 0.025905685499310493
+0.020965535193681717 0.0004627778544090688 0.0009656046167947352 0.0064941528253257275 0.024877777323126793
+0.015941444784402847 2.3312901248573326e-05 9.039048745762557e-05 0.003708748845383525 0.014433544129133224
+0.019140684977173805 0.0006230900762602687 0.000819062814116478 0.006120904348790646 0.0073091378435492516
+0.01411048136651516 2.8393842512741685e-05 0.000158495080540888 0.004765185993164778 0.005728918127715588
+0.012045788578689098 0.00037545248051173985 7.566702697658911e-05 0.004323685076087713 0.004603198729455471
+0.017597243189811707 5.3903677326161414e-05 0.0002041001571342349 0.009870978072285652 0.014317966997623444
+0.01683827117085457 1.9621296814875677e-05 0.0001478559133829549 0.007074629422277212 0.006254887208342552
+0.012677069753408432 2.3495531422668137e-05 6.5930827986449e-05 0.0075137196108698845 0.005765099078416824
+0.022118201479315758 8.072196214925498e-05 0.00010541417577769607 0.00999145582318306 0.022547846660017967
+0.017780927941203117 2.8957691029063426e-05 0.0001931230362970382 0.012506172060966492 0.011738032102584839
+0.017752204090356827 0.00026174067170359194 0.0011627840576693416 0.0054048653692007065 0.014278830960392952
+0.022305242717266083 0.0002051499905064702 0.00011602255835896358 0.006317176856100559 0.005495891906321049
+0.017676450312137604 0.0013990062288939953 0.0006801463314332068 0.005939601454883814 0.012139064259827137
+0.02336718887090683 2.120493991242256e-05 5.6944805692182854e-05 0.01218579150736332 0.015057597309350967
+0.02129012905061245 0.001701013301499188 0.0011919827666133642 0.008042865432798862 0.01734677329659462
+0.021233655512332916 2.1829164325026795e-05 0.0001686459727352485 0.0061680893413722515 0.024322230368852615
+0.016694461926817894 0.001245406805537641 0.0030344678089022636 0.011741340160369873 0.008158768527209759
+0.018761228770017624 0.0003685362753458321 0.0001560409291414544 0.006907463073730469 0.015900863334536552
+0.02304537408053875 0.0016313308151438832 0.0016873144777491689 0.0069250850938260555 0.02616562321782112
+0.01667507365345955 0.0007039668853394687 0.0005790867726318538 0.010246788151562214 0.01797885075211525
+0.018313199281692505 0.0014989713672548532 0.0010130577720701694 0.004388253204524517 0.013680855743587017
+0.01867295801639557 0.0003252051246818155 0.0008252649568021297 0.005070454441010952 0.0028852622490376234
+0.02067004330456257 0.000342427083523944 0.0004637925303541124 0.00600811280310154 0.011446376331150532
+0.016895152628421783 0.0004532494058366865 0.0009242280502803624 0.008545752614736557 0.002977039199322462
+0.017682796344161034 7.41688854759559e-05 0.0005295933806337416 0.008037133142352104 0.009393040090799332
+0.017869075760245323 0.0011290963739156723 0.0008222744218073785 0.009069430641829967 0.0118551105260849
+0.019053397700190544 0.0003021530283149332 0.00049378233961761 0.008677939884364605 0.027301905676722527
+0.02453574910759926 0.0006179057527333498 0.00042918624239973724 0.013512545265257359 0.029779566451907158
+0.014733848161995411 0.000279164727544412 9.260347724193707e-05 0.008444814011454582 0.009707150049507618
+0.01785968989133835 6.411084905266762e-05 0.0009165636147372425 0.007103945128619671 0.0054419925436377525
+0.017877619713544846 0.0002508914330974221 0.00011506998271215707 0.004026274662464857 0.0054299673065543175
+0.02238764800131321 0.000517488457262516 0.0010824011405929923 0.00826985202729702 0.011988160200417042
+0.01703709363937378 0.000604770437348634 6.798256799811497e-05 0.006629989016801119 0.012404980137944221
+0.01580154523253441 0.0016390671953558922 0.00041175654041580856 0.003716929815709591 0.014872025698423386
+0.011610572226345539 8.996785618364811e-05 0.0006920986343175173 0.011379503645002842 0.0015057023847475648
+0.013835614547133446 0.0006393701187334955 0.0009916797280311584 0.008931388147175312 0.009456844069063663
+0.02057114616036415 0.0001840815821196884 0.00041812192648649216 0.007514395751059055 0.016325537115335464
+0.014797566458582878 0.0002832527388818562 0.0010137812932953238 0.008087562397122383 0.0031905213836580515
+0.017356283962726593 0.0003490492235869169 8.395045733777806e-05 0.007440654095262289 0.008635175414383411
+0.014591949060559273 0.0006184991798363626 8.649923984194174e-05 0.005723105277866125 0.011939935386180878
+0.01601581647992134 3.067262514377944e-05 0.00018664993694983423 0.004181408789008856 0.013367431238293648
+0.01331390906125307 0.0001762247848091647 0.00017212021339219064 0.005293808877468109 0.0023213643580675125
+0.011557872407138348 3.144582296954468e-05 0.0001316538982791826 0.006315294653177261 0.004180456046015024
+0.019614294171333313 4.3290820030961186e-05 0.00044691594666801393 0.0037996971514075994 0.009971246123313904
+0.014778311364352703 6.420155114028603e-05 0.00012802898709196597 0.006592869758605957 0.006990272086113691
+0.011726127937436104 0.0002329572307644412 0.000130901433294639 0.007131629157811403 0.006854258943349123
+0.015892615541815758 3.900022420566529e-05 0.00011036327487090603 0.0107338335365057 0.015384301543235779
+0.023666968569159508 0.00015377302770502865 0.00027346593560650945 0.009043778292834759 0.018592823296785355
+0.01232406310737133 3.23884523822926e-05 0.00011955879017477855 0.006137817166745663 0.00044744001934304833
+0.016222801059484482 2.0555569790303707e-05 0.00017532719357404858 0.0059594642370939255 0.01087462529540062
+0.019333817064762115 0.00012592901475727558 0.00014635176921728998 0.006923841778188944 0.01561126671731472
+0.021037323400378227 2.7996991775580682e-05 0.00012694840552285314 0.010420288890600204 0.02007221058011055
+0.017537111416459084 0.00047425198135897517 0.0004486334219109267 0.008252338506281376 0.017189770936965942
+0.013331573456525803 2.9612780053867027e-05 6.869374919915572e-05 0.008813533931970596 0.008183368481695652
+0.015408819541335106 0.0003548863751348108 0.00022205822460819036 0.00813178438693285 0.0013940311037003994
+0.01384100690484047 6.737768126185983e-05 5.6990436860360205e-05 0.007201199419796467 0.011935156770050526
+0.017144959419965744 5.210149902268313e-05 0.00021221156930550933 0.007144498638808727 0.014853318221867085
+0.016468621790409088 1.878028706414625e-05 4.9795205995906144e-05 0.008013345301151276 0.016633205115795135
+0.016333647072315216 7.038112380541861e-05 0.00010917628969764337 0.004200716502964497 0.012265659868717194
+0.01150287315249443 2.6340650947531685e-05 0.00012798994430340827 0.004353234078735113 0.0056424387730658054
+0.023425748571753502 8.281949703814462e-05 6.799156835768372e-05 0.005301494151353836 0.012767566367983818
+0.022296488285064697 0.0001444734079996124 5.378809873946011e-05 0.004122383426874876 0.0005486660520546138
+0.012440910562872887 3.2027208362706006e-05 8.392997551709414e-05 0.003763179061934352 0.006711865309625864
+0.014551874250173569 8.519193215761334e-05 0.00022047701349947602 0.009183648973703384 0.01237143948674202
+0.01836012303829193 3.066433782805689e-05 0.000344239262631163 0.010886390693485737 0.0176934152841568
+0.0156625397503376 5.226062785368413e-05 0.0005429940647445619 0.008561008609831333 0.0056931073777377605
+0.023367702960968018 6.635823956457898e-05 0.00015924949548207223 0.007173965685069561 0.03782831132411957
+0.01353748980909586 0.00027763479738496244 8.885205170372501e-05 0.006088306661695242 0.006012545898556709
+0.01423054002225399 3.409326382097788e-05 3.763291897485033e-05 0.004144180566072464 0.011260504834353924
+0.01273376028984785 0.00011699661263264716 3.4574270102893934e-05 0.0038310675881803036 0.0005103740841150284
+0.016263285651803017 7.224986620713025e-05 8.925220026867464e-05 0.005063355900347233 0.01169146690517664
+0.014865431934595108 3.291710527264513e-05 0.00022409253870137036 0.00379933905787766 0.00868489220738411
+0.01162786316126585 5.145326576894149e-05 0.00015623572107870132 0.006794607732445002 0.001922846888191998
+0.016244791448116302 0.00010700818529585376 0.00010021127673098817 0.007647423073649406 0.01483129896223545
+0.018899600952863693 3.5048171412199736e-05 0.0002493061765562743 0.011030877009034157 0.018890483304858208
+0.021067101508378983 0.00010805961210280657 0.00023854173196014017 0.01078036893159151 0.004687774460762739
+0.0159815214574337 8.839664224069566e-05 0.00012622041685972363 0.00637351768091321 0.016256310045719147
+0.014774501323699951 6.69949731673114e-05 8.672502008266747e-05 0.008959074504673481 0.014858740381896496
+0.016888774931430817 5.0509683205746114e-05 0.00010005399963119999 0.011614413000643253 0.01215285062789917
+0.015482891350984573 8.22654037619941e-05 0.00010945866961264983 0.002755642868578434 0.009674006141722202
+0.02416282147169113 0.0001019047704176046 0.0002532132202759385 0.0031065845396369696 0.023480527102947235
+0.015480218455195427 0.0001331513631157577 0.0005868927692063153 0.006785091944038868 0.006735215894877911
+0.019839972257614136 0.0004902895889244974 7.677488611079752e-05 0.003988949581980705 0.016090158373117447
+0.010762894526124 0.00030724029056727886 8.28997217467986e-05 0.00636526383459568 0.003516939003020525
+0.025883624330163002 0.0001903139054775238 0.0004437939787749201 0.010422262363135815 0.026554768905043602
+0.01284402422606945 0.00017739032045938075 0.00023054590565152466 0.010126526467502117 0.008603201247751713
+0.01864643022418022 0.0002519557601772249 0.0005310788983479142 0.007121979724615812 0.017141759395599365
+0.016902746632695198 5.9072750445920974e-05 0.0003662144299596548 0.006006117910146713 0.008930119685828686
+0.017346790060400963 0.00044923528912477195 0.00013871690316591412 0.007488989271223545 0.00021481330622918904
+0.01912793144583702 8.844315016176552e-05 3.691824167617597e-05 0.018530769273638725 0.016675462946295738
+0.01640963926911354 0.00013052683789283037 0.00027618950116448104 0.004450349602848291 0.015346131287515163
+0.023241277784109116 7.080971408868209e-05 0.00010996383934980258 0.00783943384885788 0.02554365247488022
+0.014640185050666332 3.714995546033606e-05 3.7653109757229686e-05 0.004928336478769779 0.012812593020498753
+0.017871104180812836 0.0001804550556698814 0.0002254777937196195 0.00411166250705719 0.008396584540605545
+0.013953963294625282 8.040050306590274e-05 0.00041691146907396615 0.008820553310215473 0.011347384192049503
+0.016235200688242912 0.0004346208879724145 7.582313992315903e-05 0.009061189368367195 0.00690454663708806
+0.022912027314305305 0.0004234088701196015 0.0001713499950710684 0.0034030091483145952 0.023676561191678047
+0.013196593150496483 0.00034981913631781936 0.0007792956312187016 0.006323142442852259 0.0021544802002608776
+0.015073521994054317 0.00016214938659686595 0.0006036943523213267 0.009738362394273281 0.014243436977267265
+0.015989726409316063 0.0009126189397647977 0.0013832474360242486 0.007062586955726147 0.011578137055039406
+0.015505068004131317 0.00018517761782277375 0.001081370748579502 0.006919680163264275 0.013573586009442806
+0.01099821925163269 0.0018200791673734784 0.0006000472349114716 0.007070348132401705 0.001942066359333694
+0.020919593051075935 0.0008360195206478238 0.00019234413048252463 0.008528191596269608 0.017947133630514145
+0.020599152892827988 0.00039835317875258625 0.0006246494594961405 0.0058434875681996346 0.019968951120972633
+0.012896351516246796 0.0003044742625206709 9.53769194893539e-05 0.006998015101999044 0.003260016907006502
+0.012913571670651436 3.342266791150905e-05 0.0006262342212721705 0.008353637531399727 0.006811344996094704
+0.01960049942135811 8.051757322391495e-05 0.00014745317457709461 0.00834712851792574 0.020976494997739792
+0.023035988211631775 0.0001776899880496785 0.00010801569442264736 0.0051986477337777615 0.03643658384680748
+0.016685577109456062 4.427179374033585e-05 0.00010866016964428127 0.009240721352398396 0.01558266207575798
+0.022467680275440216 0.00010914278391283005 8.788780542090535e-05 0.005746693350374699 0.03404329717159271
+0.020562240853905678 0.0004409085086081177 6.865399336675182e-05 0.00398633349686861 0.022836104035377502
+0.016620635986328125 0.0002574983227532357 0.000262833054875955 0.004847630392760038 0.011039522476494312
+0.01466667465865612 0.0003370011691004038 0.00011497547529870644 0.006410737056285143 0.009924881160259247
+0.015991225838661194 0.0005119999987073243 8.725637599127367e-05 0.0054167709313333035 0.010189301334321499
+0.013187840580940247 3.296621434856206e-05 0.00045077691902406514 0.015626907348632812 0.0038349423557519913
+0.01672094501554966 7.504218228859827e-05 0.00033799398806877434 0.007996904663741589 0.011397399939596653
+0.022802051156759262 0.00021432689391076565 0.00039679405745118856 0.008464380167424679 0.02323080226778984
+0.02066998928785324 7.080723298713565e-05 0.0006599767948500812 0.008292271755635738 0.023506157100200653
+0.016099005937576294 0.0002497615641914308 0.00021353467309381813 0.00909390114247799 0.01572032831609249
+0.015568559058010578 0.0003200150385964662 0.00022524053929373622 0.005150910001248121 0.008737926371395588
+0.015256281942129135 0.00010005588410422206 0.00039854063652455807 0.004694782663136721 0.013807009905576706
+0.016176607459783554 0.00023482424148824066 0.00017670360102783889 0.005063591059297323 0.006419619079679251
+0.03406058996915817 4.1235114622395486e-05 0.0005427611758932471 0.003913921304047108 0.023859437555074692
+0.02409413643181324 0.0002524509618524462 0.0005796598852612078 0.005492580123245716 0.03461764380335808
+0.019693676382303238 0.0007716439431533217 0.0005222741165198386 0.011896984651684761 0.007816172204911709
+0.014413353055715561 0.0008551154751330614 0.00013767332711722702 0.011342888697981834 0.0033701849170029163
+0.01779169589281082 7.252879004226997e-05 5.225676068221219e-05 0.010022744536399841 0.012736835516989231
+0.018957991153001785 0.0003734907368198037 8.36094404803589e-05 0.010505717247724533 0.0233998391777277
+0.015053669922053814 0.0003422461450099945 0.0005139572895132005 0.010907331481575966 0.012898867018520832
+0.013545488007366657 4.6036158892093226e-05 0.0017683074111118913 0.009300597943365574 0.005345536861568689
+0.014671847224235535 0.0004630043695215136 0.00033624956267885864 0.006879238411784172 0.009765894152224064
+0.022949201986193657 0.0009645214304327965 0.0004099008801858872 0.004324484616518021 0.004684478044509888
+0.02274158224463463 8.072744094533846e-05 0.00012909028737340122 0.003191424533724785 0.009908434003591537
+0.01211676187813282 0.0012401521671563387 0.0001426990347681567 0.005937555339187384 0.0005063583375886083
+0.01701807603240013 0.0006077673751860857 0.000908088346477598 0.010132204741239548 0.012286819517612457
+0.021016037091612816 0.00017228444630745798 0.0012517892755568027 0.011763331480324268 0.024746768176555634
+0.023554792627692223 0.00018602493219077587 0.0004451847053132951 0.011047394014894962 0.03432558849453926
+0.02331320010125637 9.283314284402877e-05 0.00035711866803467274 0.013439929112792015 0.03072129748761654
+0.022052839398384094 0.0002580992877483368 0.0003220280632376671 0.007613126188516617 0.010364583693444729
+0.018503393977880478 0.000899324135389179 7.743440801277757e-05 0.004079185891896486 0.014503358863294125
+0.015482333488762379 0.0010279110865667462 0.00024828340974636376 0.009523740038275719 0.006272758357226849
+0.01984291896224022 0.0004264865710865706 0.00031341493013314903 0.00872910488396883 0.004066010005772114
+0.01646689511835575 0.0018196196760982275 0.0009369516046717763 0.007273092400282621 0.01251675933599472
+0.02439049258828163 0.00013051132555119693 0.0005621879827231169 0.010535009205341339 0.02407201938331127
+0.023198101669549942 0.0006676584016531706 0.0014460351085290313 0.01838498003780842 0.022000975906848907
+0.015310382470488548 0.0006738697411492467 0.000848027179017663 0.017258349806070328 0.009001914411783218
+0.02253650687634945 0.0009048799402080476 0.0020780207123607397 0.02129802294075489 0.019819237291812897
+0.01820923201739788 0.0014846465783193707 0.0003111587429884821 0.007136348634958267 0.015046299435198307
+0.01757979951798916 0.002761256415396929 0.0014424622058868408 0.00564689701423049 0.0125473253428936
+0.014477482065558434 0.00040123023791238666 0.00045296174357645214 0.005975847598165274 0.00042646718793548644
+0.019343676045536995 0.002076509641483426 0.0012083473848178983 0.003986361902207136 0.012999405153095722
+0.02063998207449913 0.00018174335127696395 0.0008615771657787263 0.0064933327957987785 0.018804354593157768
+0.02138550765812397 0.0004776765126734972 0.0013501718640327454 0.007086829282343388 0.011940675787627697
+0.01871170662343502 0.0006372728385031223 0.0005148763302713633 0.005502159707248211 0.016813641414046288
+0.012186747044324875 0.0002129244530806318 0.000380251178285107 0.009080111049115658 0.004942948929965496
+0.021299835294485092 0.0009535606950521469 0.0004622679844032973 0.010216432623565197 0.023143956437706947
+0.019011694937944412 0.00023727280495222658 7.291923247976229e-05 0.005181873217225075 0.026524432003498077
+0.024249911308288574 8.443387923762202e-05 0.0002385899715591222 0.005022794008255005 0.037162479013204575
+0.014355218037962914 5.338047776604071e-05 0.00023527421581093222 0.006117034703493118 0.00401674397289753
+0.014674253761768341 5.27202028024476e-05 0.0002025961148319766 0.006937147583812475 0.0038432879373431206
+0.01713067665696144 0.0003121178306173533 0.00012874016829300672 0.0048021660186350346 0.005247517954558134
+0.017475219443440437 3.445069887675345e-05 0.000520107860211283 0.00955596286803484 0.009704620577394962
+0.01454588957130909 3.5349639802007005e-05 0.0003487278881948441 0.012998727150261402 0.011590925976634026
+0.014525378122925758 0.00023407628759741783 0.0008036770159378648 0.013127614744007587 0.007318167015910149
+0.013098386116325855 2.4247985493275337e-05 0.0009104482014663517 0.006557153537869453 0.003657693974673748
+0.013473174534738064 0.0014436283381655812 0.0001665346062509343 0.006256199907511473 0.0069443900138139725
+0.01748194731771946 0.0006337261293083429 0.000107667910924647 0.008723000064492226 0.01037826668471098
+0.020328380167484283 0.000634059077128768 0.0009359902469441295 0.005206826608628035 0.0206264890730381
+0.016859270632267 0.0002703650388866663 0.00010159210069105029 0.0054770237766206264 0.013049724511802197
+0.016025306656956673 0.00040287707815878093 0.001705161645077169 0.006641595158725977 0.014961019158363342
+0.012740494683384895 3.741586988326162e-05 0.0005080295377410948 0.005303655285388231 0.0012500189477577806
+0.012665558606386185 0.0008207783685065806 0.00010242850839858875 0.007386334706097841 0.0019480487098917365
+0.01413758099079132 0.00010301871225237846 7.920702773844823e-05 0.005751083604991436 0.000254566315561533
+0.013327283784747124 0.00024451647186651826 0.00033124021138064563 0.007190181873738766 0.0071567511186003685
+0.016985878348350525 0.00012837447866331786 7.673244545003399e-05 0.008263865485787392 0.007783855777233839
+0.014057359658181667 7.66367738833651e-05 0.00046519117313437164 0.014488511718809605 0.0019184182165190578
+0.02289358153939247 0.00013412807311397046 0.0009985838551074266 0.009610985405743122 0.029236802831292152
+0.020126724615693092 0.00019303425506222993 0.00018399866530671716 0.0045313783921301365 0.023100148886442184
+0.019036345183849335 0.0006826443132013083 0.00011100450501544401 0.006104703992605209 0.021809957921504974
+0.02036287635564804 8.311757846968248e-05 0.00014711842231918126 0.004886016249656677 0.02312065288424492
+0.021319396793842316 0.00026072660693898797 0.00011474219354568049 0.004932511132210493 0.010059714317321777
+0.017280731350183487 7.597938383696601e-05 0.00038938247598707676 0.0037334139924496412 0.011372746899724007
+0.0195707269012928 5.586106635746546e-05 0.0009014597744680941 0.003423243062570691 0.009504248388111591
+0.018822625279426575 0.00026620461721904576 0.00027401986881159246 0.0050861830823123455 0.01393928937613964
+0.013471152633428574 0.0004670910711865872 0.0002643801271915436 0.013662041164934635 0.00789058580994606
+0.015226950868964195 0.00019067229004576802 0.0005132905207574368 0.011806447990238667 0.009658192284405231
+0.02994065172970295 5.9066631365567446e-05 0.0002909887407440692 0.009568440727889538 0.05313517898321152
+0.026490824297070503 0.0006917198188602924 0.00012546306243166327 0.009457755833864212 0.0407261848449707
+0.01257534883916378 0.00010843157360795885 0.0005989063647575676 0.005480408668518066 0.009984447620809078
+0.026604149490594864 5.355080065783113e-05 0.001170141389593482 0.004122598096728325 0.027170248329639435
+0.019165169447660446 0.0005864204722456634 0.00021008703333791345 0.005004365462809801 0.0007343339966610074
+0.019359631463885307 0.00031980633502826095 0.0008561995346099138 0.003936576656997204 0.010852784849703312
+0.018117394298315048 0.0009319116943515837 5.3374118579085916e-05 0.00644700089469552 0.007294660434126854
+0.015156477689743042 0.0016437589656561613 0.00018663349328562617 0.008552515879273415 0.007011236622929573
+0.013936587609350681 5.504667205968872e-05 0.0008919183746911585 0.012975827790796757 0.00740901380777359
+0.024420633912086487 0.00024090151418931782 0.0006171524873934686 0.011427721939980984 0.03574208542704582
+0.017503952607512474 3.0019040423212573e-05 0.0005159910069778562 0.01636989787220955 0.019736766815185547
+0.022629117593169212 0.00013913793372921646 0.002151314401999116 0.009069925174117088 0.03372243791818619
+0.023000523447990417 0.0010086115216836333 0.00010362040484324098 0.004614724311977625 0.020841026678681374
+0.024364640936255455 0.0023711787071079016 0.0007033309084363282 0.0043597049079835415 0.010215734131634235
+0.018101878464221954 5.2419105486478657e-05 6.96199931553565e-05 0.005247439723461866 0.011564591899514198
+0.015773439779877663 0.00134125049225986 0.00013538221537601203 0.003086056327447295 0.0076953088864684105
+0.022179976105690002 0.00015326928405556828 0.0014079109532758594 0.006074862089008093 0.01421757135540247
+0.022005198523402214 4.7108831495279446e-05 0.002368348650634289 0.009762248024344444 0.014682170003652573
+0.018043268471956253 0.0007477918406948447 0.0006140977493487298 0.020181044936180115 0.01473400741815567
+0.012583757750689983 0.0008614095859229565 0.0012791132321581244 0.021281573921442032 0.003560785437002778
+0.02263224497437477 0.00026643837918527424 0.00046924446360208094 0.013696671463549137 0.028874553740024567
+0.017254695296287537 0.002013254212215543 0.00012493456597439945 0.006208885461091995 0.016912469640374184
+0.019820887595415115 0.0005392428138293326 0.00097910116892308 0.008526254445314407 0.016581255942583084
+0.01979774981737137 2.5625922717154026e-05 0.0018359103705734015 0.005280361510813236 0.015503433533012867
+0.020590269938111305 0.0002437014481984079 3.5738958104047924e-05 0.004001423250883818 0.020806221291422844
+0.01893683522939682 0.0002869645250029862 0.001169556169770658 0.004450357519090176 0.007155937142670155
+0.020111776888370514 0.0008977894904091954 0.00011646901111816987 0.007058531511574984 0.013389143161475658
+0.018453828990459442 0.0008645659545436502 0.0005031627370044589 0.009782341308891773 0.014175725169479847
+0.016922341659665108 0.00014726117660757154 0.0029575289227068424 0.016688989475369453 0.00780381727963686
+0.027014045044779778 0.0004251272766850889 0.00047632333007641137 0.017558546736836433 0.04094623401761055
+0.013756811618804932 0.0009145936346612871 0.0008485916769132018 0.009557516314089298 0.010501701384782791
+0.014614882878959179 0.00025132670998573303 0.00022891932167112827 0.008152212016284466 0.002216618275269866
+0.02381938323378563 0.0017415358452126384 6.161784403957427e-05 0.006764697842299938 0.0307258702814579
+0.015404261648654938 0.0003146312083117664 0.0009466730989515781 0.006161128170788288 0.005843627732247114
+0.015878163278102875 5.2086892537772655e-05 0.0016005221987143159 0.0027987537905573845 0.008869355544447899
+0.017764633521437645 0.00018234059098176658 5.582338781096041e-05 0.004588339477777481 0.016003279015421867
+0.015222707763314247 0.0002071110502583906 0.000314692675601691 0.004143614787608385 0.004846334457397461
+0.019925309345126152 0.0003129343967884779 0.00012490125664044172 0.0040578884072601795 0.01528829988092184
+0.012067792937159538 0.0006490862579084933 0.00015991511463653296 0.008692286908626556 0.0030757810454815626
+0.018255341798067093 0.0001245185121661052 0.0007000070763751864 0.011605066247284412 0.021424707025289536
+0.022986596450209618 6.009953358443454e-05 0.0004897542530670762 0.011090508662164211 0.034046269953250885
+0.014144973829388618 6.62119928165339e-05 0.00015247263945639133 0.009660658426582813 0.012394113466143608
+0.01738155633211136 3.707601354108192e-05 0.00024325815320480615 0.007096387445926666 0.01303523126989603
+0.01624954678118229 0.0002318705664947629 6.632388976868242e-05 0.006165269296616316 0.010915867052972317
+0.014389934949576855 0.0004111594462301582 4.954492760589346e-05 0.004211014602333307 0.005226787645369768
+0.013301558792591095 4.0331153286388144e-05 0.00011474374332465231 0.004306699614971876 0.0017399393254891038
+0.015453542582690716 0.00010983451647916809 0.00022485408408101648 0.005099967587739229 0.012966718524694443
+0.017376499250531197 5.332063665264286e-05 0.0001505782565800473 0.011134568601846695 0.020359763875603676
+0.019290132448077202 5.082425559521653e-05 0.00011071460176026449 0.005326311569660902 0.020275160670280457
+0.01821395754814148 3.400407877052203e-05 0.00014026951976120472 0.006931346375495195 0.02266654185950756
+0.01705746352672577 6.306356226559728e-05 0.0001911185827339068 0.0078047458082437515 0.013774586841464043
+0.020649436861276627 3.369095429661684e-05 7.396042929030955e-05 0.0056793480180203915 0.018182950094342232
+0.01862792856991291 2.116736322932411e-05 8.076890662778169e-05 0.004760189913213253 0.014851651154458523
+0.01846185512840748 4.694303061114624e-05 0.00012659415369853377 0.006297948304563761 0.008724616840481758
+0.01340961642563343 5.0724986067507416e-05 9.552728442940861e-05 0.005811724346131086 0.0026997809763997793
+0.009971577674150467 8.653372788103297e-05 0.00014323255163617432 0.004751645028591156 0.0025651731994003057
+0.021511251106858253 4.465166057343595e-05 6.068034781492315e-05 0.011288230307400227 0.0259128175675869
+0.023672636598348618 2.5125642423518002e-05 0.00016015181608963758 0.01678679697215557 0.027566753327846527
+0.012876631692051888 7.680282578803599e-05 0.0007646470330655575 0.007646386977285147 0.01023530587553978
+0.01366677787154913 0.000292851502308622 7.719836139585823e-05 0.00787422340363264 0.012007338926196098
+0.015529327094554901 0.001018497976474464 0.0007514620083384216 0.008255610242486 0.011257560923695564
+0.023578910157084465 0.00010002704948419705 4.30124782724306e-05 0.004777312278747559 0.016145648434758186
+0.01583232544362545 0.0008726783562451601 0.00011114415974589065 0.0052382429130375385 0.0003270100278314203
+0.020215483382344246 3.799065962084569e-05 0.001088542165234685 0.006322382017970085 0.006512513849884272
+0.011664294637739658 4.1309445805381984e-05 0.0011138179106637836 0.01672147773206234 0.0014562252908945084
+0.01578550413250923 0.0009531003306619823 0.00035160561674274504 0.006705587729811668 0.009345992468297482
+0.017995202913880348 0.0004049361450597644 0.0009274880867451429 0.00836880225688219 0.017413988709449768
+0.019253969192504883 0.0007980391965247691 0.00014337232278194278 0.010193604975938797 0.023501772433519363
+0.014523060992360115 0.0013931433204561472 0.00010534502507653087 0.006836037151515484 0.011901958845555782
+0.02109157107770443 0.00010423339699627832 0.0007995889172889292 0.0058770012110471725 0.01662902720272541
+0.013304339721798897 2.1154488422325812e-05 0.00030542424065060914 0.004676830023527145 0.0031085277441889048
+0.02413148246705532 0.0001288151106564328 0.0004312306409701705 0.002681537764146924 0.031304746866226196
+0.02152360789477825 3.8896079786354676e-05 0.0009759258828125894 0.00635555712506175 0.007853955961763859
+0.017293576151132584 0.0014264172641560435 5.937893729424104e-05 0.005190675146877766 0.0109104597941041
+0.015079332515597343 0.0009894708637148142 4.186115984339267e-05 0.005343763157725334 0.005053117871284485
+0.011802232824265957 0.0001369924721075222 0.0005699793691746891 0.006828537676483393 0.006311762612313032
+0.016997430473566055 0.0002863590489141643 0.0002237270527984947 0.012316589243710041 0.01919056475162506
+0.0242543276399374 4.977708522346802e-05 0.0013465352822095156 0.012078268453478813 0.027964415028691292
+0.018901851028203964 4.1673782106954604e-05 0.0015293046599254012 0.00927029736340046 0.015386526472866535
+0.017738912254571915 0.0014111676719039679 6.273404142120853e-05 0.0052698287181556225 0.005100148729979992
+0.018658500164747238 0.0010398569284006953 0.0003018606803379953 0.004343239590525627 0.01901622675359249
+0.012697539292275906 0.0003026414487976581 8.35708124213852e-05 0.0039594778791069984 0.0004633019270841032
+0.015933029353618622 0.001552476198412478 0.0001434714940842241 0.004350544419139624 0.00244697742164135
+0.01252038311213255 0.00014851060404907912 0.0006525717908516526 0.005852897185832262 0.0074361776933074
+0.018965944647789 5.446372597361915e-05 0.0009691866580396891 0.006751815788447857 0.02631978876888752
+0.014172274619340897 0.00012046965275658295 0.00017939205281436443 0.006577981170266867 0.012036471627652645
+0.014716598205268383 6.241939263418317e-05 0.0003723606059793383 0.01136218011379242 0.011455342173576355
+0.012039056979119778 0.0003726287104655057 5.4377716878661886e-05 0.007445116061717272 0.006819391623139381
+0.024094941094517708 0.0009449712815694511 0.00015077243733685464 0.0072189196944236755 0.016085442155599594
+0.026479938998818398 6.580464832950383e-05 0.0001036824396578595 0.019465893507003784 0.014008465223014355
+0.01888321340084076 0.0004743935714941472 0.00021318052313290536 0.005909716244786978 0.019099952653050423
+0.012896589934825897 0.0003978122549597174 0.00023993696959223598 0.007987198419868946 0.007191954180598259
+0.012347673065960407 4.0925278881331906e-05 7.168387674028054e-05 0.0062272739596664906 0.000590478302910924
+0.01837236061692238 5.196852725930512e-05 6.543378549395129e-05 0.00443292735144496 0.02082795277237892
+0.01772543229162693 0.00013294427481014282 9.587033855495974e-05 0.00700031453743577 0.02191517874598503
+0.015056002885103226 3.7858320865780115e-05 0.00026377258473075926 0.010805344209074974 0.010212494991719723
+0.015835091471672058 2.4759014195296913e-05 0.0003444513422437012 0.008493251167237759 0.009511495009064674
+0.016719844192266464 0.000594496785197407 0.0004672806244343519 0.004804596770554781 0.012783833779394627
+0.02010297402739525 0.00012966404028702527 0.0008931206539273262 0.006570173893123865 0.023590777069330215
+0.01364006195217371 0.0013815517304465175 0.00011555496894288808 0.006659595295786858 0.001371593913063407
+0.015506923198699951 0.0014541231794282794 0.00031468551605939865 0.0035261011216789484 0.008676452562212944
+0.016195975244045258 0.0003365128068253398 0.0010581623064354062 0.00429871492087841 0.0016845512436702847
+0.022065090015530586 0.0007484719390049577 0.0009507212671451271 0.009641796350479126 0.01645771786570549
+0.01844143122434616 0.0004040680069010705 0.0012459170538932085 0.009809068404138088 0.020281672477722168
+0.01954733580350876 0.0005936856032349169 0.002729499014094472 0.008890682831406593 0.0039628189988434315
+0.01344364695250988 0.0010941592045128345 0.0001640951377339661 0.006680366117507219 0.004674864467233419
+0.02321830950677395 0.0030983055476099253 0.0010393692646175623 0.0031462605111300945 0.01675589196383953
+0.014916755259037018 2.9300010282895528e-05 0.0002366047992836684 0.006364098284393549 0.014677969738841057
+0.0154594536870718 0.0033165467903017998 0.001456585479900241 0.01089437399059534 0.0031616929918527603
+0.020637016743421555 0.0011777556501328945 0.0005259288009256124 0.004889951553195715 0.018333984538912773
+0.019163722172379494 0.0016058672918006778 0.002710553817451 0.003990732133388519 0.013572349213063717
+0.023758485913276672 0.0009291901951655746 0.0014521017437800765 0.005679067224264145 0.017330702394247055
+0.015929611399769783 0.0007361466414295137 0.000542151799891144 0.01474345289170742 0.000684652419295162
+0.019895680248737335 0.0012292125029489398 0.0010685373563319445 0.014625298790633678 0.020904259756207466
+0.026556722819805145 0.00035398436011746526 0.0002667822118382901 0.013560892082750797 0.04001423716545105
+0.01860128343105316 0.0014077486703172326 0.0008584224269725382 0.008561281487345695 0.01925014704465866
+0.017810514196753502 9.592313290340826e-05 0.0004702691803686321 0.00926091056317091 0.01665414124727249
+0.017659923061728477 0.0009710645535960793 0.0018951187375932932 0.008579679764807224 0.01290140487253666
+0.01891385205090046 0.0001148610026575625 0.00010591550380922854 0.005514712538570166 0.011438032612204552
+0.024138042703270912 0.0018341970862820745 0.0012678767088800669 0.003850230248644948 0.009712588042020798
+0.024285946041345596 0.00018700910732150078 0.00038232814404182136 0.004281248431652784 0.009488632902503014
+0.019327322021126747 0.0010696789249777794 0.00044983281986787915 0.008235855028033257 0.014382061548531055
+0.01886674016714096 0.0011235892307013273 0.0006987368105910718 0.01719360239803791 0.013173185288906097
+0.01444222405552864 0.00020864064572378993 0.0008004461997188628 0.010156266391277313 0.0088240010663867
+0.024187251925468445 0.00035952741745859385 0.0007845931104384363 0.011953707784414291 0.026359735056757927
+0.01927529089152813 0.00016589928418397903 0.00029461595113389194 0.012946544215083122 0.021284883841872215
+0.021122410893440247 0.00013133652100805193 0.0012578319292515516 0.007590315770357847 0.029051635414361954
+0.019715365022420883 8.520315168425441e-05 0.0005662900512106717 0.004877651110291481 0.010708056390285492
+0.0179427620023489 0.0015979191521182656 0.0005050859763287008 0.01552007719874382 0.0009778998792171478
+0.019733447581529617 0.0006402476574294269 0.0010804061312228441 0.004615932237356901 0.011029024608433247
+0.016121134161949158 0.0014528600731864572 0.00012096478167222813 0.006669805850833654 0.006228671874850988
+0.017716526985168457 0.002514478750526905 0.0003118660824839026 0.008672804571688175 0.013339674100279808
+0.01606064848601818 5.7463254051981494e-05 0.0008531819912604988 0.012263310141861439 0.013655873015522957
+0.022618092596530914 0.00043151379213668406 0.0010069315321743488 0.010577327571809292 0.024827612563967705
+0.016392434015870094 3.267359352321364e-05 0.0009111781837418675 0.011741231195628643 0.00863039493560791
+0.018316231667995453 3.9891121559776366e-05 0.0012360874097794294 0.011881295591592789 0.009038059040904045
+0.016681168228387833 0.0005853970651514828 0.00017110827320721 0.0067966775968670845 0.014748156070709229
+0.0249526035040617 0.0008827142301015556 0.00021105010819155723 0.005309871397912502 0.022930461913347244
+0.018753357231616974 7.153774640755728e-05 0.00027049772324971855 0.004818044137209654 0.013709712773561478
+0.016498152166604996 0.0006380808772519231 0.0002917333331424743 0.007278874982148409 0.01744101569056511
+0.018261853605508804 0.00016567631973885 0.00015857718244660646 0.009908640757203102 0.00783265009522438
+0.016510048881173134 3.474821278359741e-05 9.198721090797335e-05 0.0066791120916605 0.0146149517968297
+0.0185165386646986 8.810655708657578e-05 0.0001792146940715611 0.007222854066640139 0.020297806710004807
+0.018356697633862495 0.00024066002515610307 0.0002473140775691718 0.01068157609552145 0.015501730144023895
+0.01969246193766594 3.89581655326765e-05 7.841696060495451e-05 0.012178119271993637 0.0016170904273167253
+0.01603877544403076 7.885682134656236e-05 0.0002514262159820646 0.006305169314146042 0.008785810321569443
+0.01628299616277218 0.00016438798047602177 0.0006576242740266025 0.005955031607300043 0.0033659839536994696
+0.01804092898964882 6.227642734302208e-05 0.0009908430511131883 0.008776123635470867 0.010230779647827148
+0.01659957878291607 0.0013548347633332014 0.00035587037564255297 0.007371039129793644 0.010637473315000534
+0.01182081364095211 0.0009376826928928494 0.0007924206438474357 0.0035009635612368584 0.0006590892444364727
+0.01379408035427332 0.0009495065314695239 0.0003447585622780025 0.007230805233120918 0.011048909276723862
+0.011652064509689808 0.0013685188023373485 0.00030589691596105695 0.007611805107444525 0.0029053406324237585
+0.015622463077306747 0.00023134125513024628 0.001869900501333177 0.009931494481861591 0.013663673773407936
+0.016983017325401306 0.00020208572095725685 0.0010267473990097642 0.00892663188278675 0.01854080520570278
+0.01201391126960516 0.0010576022323220968 0.0008865117561072111 0.00990681629627943 0.0030367278959602118
+0.013951286673545837 9.347032755613327e-05 0.00024319428484886885 0.005282746162265539 0.011243226937949657
+0.018814602866768837 0.0007879363838583231 0.00015201653877738863 0.0035627889446914196 0.01586329936981201
+0.017275165766477585 0.0002085007436107844 0.00028119373018853366 0.0067482590675354 0.0019524383824318647
+0.023018574342131615 4.5361248339759186e-05 0.00036384366103447974 0.005279439967125654 0.008364133536815643
+0.013861701823771 0.000139469702844508 9.81813864200376e-05 0.005268582608550787 0.004691847134381533
+0.012840020470321178 7.759006984997541e-05 9.518123260932043e-05 0.00435500917956233 0.008072633296251297
+0.012705697678029537 0.00042586616473272443 0.00020190396753605455 0.012922667898237705 0.0024133336264640093
+0.018841613084077835 0.00011939687828999013 0.0009972639381885529 0.009662069380283356 0.0044240932911634445
+0.020347703248262405 6.629236304434016e-05 0.0006201740470714867 0.005241821054369211 0.015439307317137718
+0.011347346939146519 0.001241675578057766 0.000735097099095583 0.010395459830760956 0.005449031479656696
+0.013734356500208378 4.0404716855846345e-05 6.598124309675768e-05 0.007981541566550732 0.009713942185044289
+0.015815287828445435 0.0016741915605962276 0.001286463811993599 0.004938124679028988 0.005764379166066647
+0.011864179745316505 4.815706779481843e-05 4.7977846406865865e-05 0.003356967121362686 0.007691000588238239
+0.013000423088669777 0.0012691017473116517 0.0021786505822092295 0.004524480551481247 0.008053835481405258
+0.01508798636496067 0.00012552310363389552 0.00010825279605342075 0.0039308080449700356 0.007290954235941172
+0.01701795682311058 0.0011488316813483834 0.0008092073840089142 0.006234443746507168 0.004103870131075382
+0.01322954148054123 3.9222464693011716e-05 0.00014057739463169128 0.006455042865127325 0.009597860276699066
+0.01822661980986595 0.0008540187845937908 0.001284749829210341 0.009076671674847603 0.022915992885828018
+0.013696416281163692 0.00014532121713273227 0.0001638211979297921 0.00759224034845829 0.006807580590248108
+0.020369376987218857 0.0009141123155131936 0.001427489914931357 0.008912032470107079 0.02442525327205658
+0.016069043427705765 0.00015637578326277435 7.639022805960849e-05 0.003986500669270754 0.013537873513996601
+0.012300435453653336 0.001221577520482242 0.0006326230359263718 0.0049109626561403275 9.554318967275321e-05
+0.015454317443072796 4.697971235145815e-05 0.00011170456127729267 0.004092133138328791 0.013478538952767849
+0.017798900604248047 0.0002572461380623281 0.00028721027774736285 0.005885941442102194 0.0017980131087824702
+0.013591533526778221 9.542061161482707e-05 6.19547936366871e-05 0.007879759185016155 0.004921372979879379
+0.014583989977836609 2.5843464754871093e-05 0.00017259320884477347 0.004668341018259525 0.008549523539841175
+0.013475662097334862 5.487148882821202e-05 0.0002356112381676212 0.0071343728341162205 0.0076746633276343346
+0.019231725484132767 3.87527616112493e-05 8.992729271994904e-05 0.010118285194039345 0.01936907321214676
+0.015450624749064445 0.00043751075281761587 0.0005545750027522445 0.007498983759433031 0.00741761876270175
+0.020067786797881126 1.7508729797555134e-05 0.00029908044962212443 0.004373499192297459 0.02548300474882126
+0.018500491976737976 0.0013234505895525217 0.0004120723169762641 0.0030701083596795797 0.02309746854007244
+0.011751813814043999 0.00032060977537184954 7.807482325006276e-05 0.005459039006382227 0.00022530683781951666
+0.011533337645232677 0.0006702844402752817 0.0009979434544220567 0.005514569114893675 0.0027236470486968756
+0.026098646223545074 4.065897155669518e-05 4.8616944695822895e-05 0.003843722864985466 0.0402863435447216
+0.016929924488067627 0.0007642134441994131 0.002376101678237319 0.007705444935709238 0.006231443956494331
+0.017311913892626762 0.0002265634830109775 0.00014245502825360745 0.007629793137311935 0.018265776336193085
+0.01302412897348404 0.0019252129131928086 0.0009600165649317205 0.007665739394724369 0.005080582108348608
+0.01938677206635475 4.587792864185758e-05 0.00024916097754612565 0.008472273126244545 0.012692210264503956
+0.010857699438929558 0.0011986723402515054 0.0012228115228936076 0.013479841873049736 0.003046472556889057
+0.013987241312861443 9.569479880155995e-05 0.000683233723975718 0.010705388151109219 0.010774205438792706
+0.01490216888487339 0.0005930193001404405 0.0025486908853054047 0.008541890420019627 0.01044425368309021
+0.015818549320101738 0.0007499409839510918 0.00013439770555123687 0.006425929255783558 0.013744237832725048
+0.012982651591300964 0.0015074026305228472 0.0006136853480711579 0.005682831164449453 0.0016450600232928991
+0.017812218517065048 0.0003574758302420378 0.0006679075304418802 0.005254654213786125 0.021969208493828773
+0.016046414151787758 0.0008406824199482799 0.0005637797294184566 0.006596536375582218 0.01743459515273571
+0.010835723951458931 0.00025531399296596646 0.0011031879112124443 0.003568995976820588 0.005826149135828018
+0.011666541919112206 0.0002239374298369512 0.0012029564240947366 0.0030276249162852764 0.0008972007781267166
+0.014887663535773754 0.00039201354957185686 5.1623759645735845e-05 0.004320555366575718 0.01209504809230566
+0.014998179860413074 0.00019135749607812613 0.00017824482347350568 0.008681499399244785 0.010043773800134659
+0.01725137233734131 3.0751678423257545e-05 0.00031005433993414044 0.007942063733935356 0.006102127023041248
+0.013963479548692703 7.881812052801251e-05 0.000206394906854257 0.0055618793703615665 0.0020350744016468525
+0.01552458293735981 0.00011080846888944507 6.443277379730716e-05 0.004490815103054047 0.00922235008329153
+0.014771634712815285 0.00043578105396591127 0.00018012824875768274 0.01626700721681118 0.0031242151744663715
+0.011541323736310005 0.00017986036255024374 0.0011808706913143396 0.013068375177681446 0.00399915361776948
+0.018639100715517998 0.000398290780140087 0.0023532358463853598 0.003452765289694071 0.024326054379343987
+0.01609746739268303 0.0012411783682182431 0.0002403484977548942 0.0037993155419826508 0.012979414314031601
+0.010527834296226501 0.001458513899706304 0.0008121441933326423 0.006284751463681459 0.0008076227968558669
+0.012349826283752918 0.0006516518187709153 0.0005614800611510873 0.0033670903649181128 0.007076079957187176
+0.015889620408415794 0.0011041959514841437 0.00040965733933262527 0.003201200859621167 0.016593122854828835
+0.022975027561187744 0.00029598522814922035 0.0017150609055534005 0.006257859990000725 0.033446118235588074
+0.01589629426598549 0.0001664112351136282 0.0013206040021032095 0.004641527310013771 0.00767724122852087
+0.0140891307964921 0.0009351297630928457 0.00027378727099858224 0.00426161615177989 0.0072577050887048244
+0.018576867878437042 0.00022005043865647167 0.00011656473361654207 0.006327958777546883 0.013323933817446232
+0.016760364174842834 0.0003945524513255805 0.0004577250510919839 0.00886718649417162 0.018846238031983376
+0.017928149551153183 6.119762838352472e-05 0.0002261544723296538 0.00722259096801281 0.01773219369351864
+0.01499170996248722 2.763911834335886e-05 0.0006488008657470345 0.008359240368008614 0.009380950592458248
+0.012964758090674877 0.0003430457436479628 0.00015355301729869097 0.005461674649268389 0.01155070960521698
+0.015146609395742416 0.00012218914343975484 0.00016097421757876873 0.006334258709102869 0.016089512035250664
+0.0213785357773304 0.0002481119299773127 0.0003463513858150691 0.0037062230985611677 0.01984441466629505
+0.014026197604835033 3.3244436053792015e-05 8.413616887992248e-05 0.003156057558953762 0.0020250449888408184
+0.01348950993269682 0.00012977051665075123 0.0007329752552323043 0.0048570893704891205 0.01497586164623499
+0.016669798642396927 0.00020887974824290723 8.580115536460653e-05 0.0029481430537998676 0.01943083480000496
+0.013692921958863735 0.0005611259839497507 0.0003126756055280566 0.0042907255701720715 0.009938724339008331
+0.019136961549520493 8.203073957702145e-05 0.00015257328050211072 0.004627571906894445 0.00904709193855524
+0.01821981742978096 0.0006161971250548959 0.0004422418132890016 0.0043464647606015205 0.0067219203338027
+0.010923855006694794 2.548158408899326e-05 7.183819252531976e-05 0.006654839497059584 0.00417015515267849
+0.015168578363955021 0.0002704806684050709 0.0004455797025002539 0.009983439929783344 0.012145777232944965
+0.01531794760376215 0.00011060776159865782 0.00025436843861825764 0.009191080927848816 0.005774666089564562
+0.018987108021974564 0.00010741825099103153 6.0165435570525005e-05 0.006760875694453716 0.02518102154135704
+0.014697752892971039 0.00013076596951577812 7.03091427567415e-05 0.005805240944027901 0.010033879429101944
+0.012608484365046024 6.33003655821085e-05 0.0004972786991856992 0.005445788148790598 0.00420727813616395
+0.01650189980864525 5.4430282034445554e-05 0.00011821692169178277 0.004309067036956549 0.005353847518563271
+0.013767477124929428 0.00045826067798770964 0.0004767903301399201 0.010171215049922466 0.00835310760885477
+0.0195521991699934 2.673462404345628e-05 0.0002968674525618553 0.0058441669680178165 0.013992007821798325
+0.014785714447498322 0.0004257661057636142 7.225231820484623e-05 0.006042119115591049 0.005192013923078775
+0.02120243012905121 0.0002294771111337468 0.00035110535100102425 0.008217470720410347 0.013608931563794613
+0.017941877245903015 9.697351924842224e-05 0.0003202311636414379 0.008491781540215015 0.02260197140276432
+0.013843083754181862 0.0001100180161301978 0.0002711109700612724 0.01119737047702074 0.0039583975449204445
+0.014929484575986862 6.398255936801434e-05 0.00047141261165961623 0.007533724419772625 0.006449936889111996
+0.014984508976340294 0.000301291118375957 0.00024892709916457534 0.007975799031555653 0.005585669074207544
+0.02236732468008995 0.000703248952049762 0.00018544016347732395 0.007466509472578764 0.01854868233203888
+0.01871628686785698 0.00012398850230965763 0.0005594290560111403 0.004722404759377241 0.011337537318468094
+0.015609250403940678 3.954438216169365e-05 0.0002717208117246628 0.004414363298565149 0.01154913380742073
+0.01179051585495472 3.4677217627177015e-05 0.00012114960554754362 0.004468399565666914 0.0002718196192290634
+0.021191727370023727 8.293640712508932e-05 0.0002624295884743333 0.004042205400764942 0.00755318533629179
+0.016512088477611542 8.653965051053092e-05 0.00044538892689161 0.003082080977037549 0.006958258803933859
+0.017222795635461807 0.00011296324373688549 0.0008978846599347889 0.005697747692465782 0.015576576814055443
+0.01203151885420084 0.0006154485745355487 0.0002452734624966979 0.008059749379754066 0.006945619359612465
+0.0189371220767498 0.000300166429951787 0.0003840009158011526 0.0064291865564882755 0.02235957235097885
+0.013851652853190899 0.0006551843835040927 0.0003381039132364094 0.007132433354854584 0.009886346757411957
+0.00997729878872633 0.0004799067392013967 0.00017579240375198424 0.005956238601356745 0.0004435236915014684
+0.013829074800014496 9.39928722800687e-05 0.0012675471371039748 0.006178579293191433 0.005476393271237612
+0.019568998366594315 4.998758959118277e-05 0.0003232277522329241 0.004286525305360556 0.023955492302775383
+0.01695862226188183 0.0004743776225950569 0.0003351405030116439 0.012451180256903172 0.0028508431278169155
+0.0173652283847332 0.0002788811398204416 0.0002732205612119287 0.0051427194848656654 0.013874799013137817
+0.014073144644498825 0.0002908791939262301 0.00013223457790445536 0.004007674753665924 0.00658405851572752
+0.01703285239636898 0.00032340336474590003 0.0009687317069619894 0.00916769914329052 0.011564723215997219
+0.012246714904904366 4.090317088412121e-05 0.00024092548119369894 0.006802360527217388 0.005594683811068535
+0.013027244247496128 0.0008384116226807237 0.00039668474346399307 0.008039635606110096 0.005939285270869732
+0.016356587409973145 5.988411430735141e-05 0.00010088883573189378 0.006149961147457361 0.01821933127939701
+0.015429317019879818 0.000708351843059063 0.0004182799893897027 0.007612132001668215 0.016122736036777496
+0.013260405510663986 5.465505819302052e-05 5.954768494120799e-05 0.004993009846657515 0.006194258574396372
+0.01761772483587265 0.00021255605679471046 0.0005686648073606193 0.005582810379564762 0.012415573932230473
+0.014129094779491425 5.6110740842996165e-05 4.298618296161294e-05 0.005962538067251444 0.0048589156940579414
+0.01789214089512825 0.00024293760361615568 0.00027737035998143256 0.007576241623610258 0.009447562508285046
+0.01986689493060112 4.5369306462816894e-05 4.316281047067605e-05 0.005999550223350525 0.015800679102540016
+0.01688339374959469 0.00015371681365650147 5.926505400566384e-05 0.00885731540620327 0.014382574707269669
+0.018710533156991005 3.507382280076854e-05 0.00018707686103880405 0.009201161563396454 0.014082835055887699
+0.01222559530287981 4.965611515217461e-05 0.0002866694994736463 0.005830959416925907 0.006323707289993763
+0.012768066488206387 3.167772956658155e-05 0.00012948591029271483 0.00527529139071703 0.00877581536769867
+0.014623540453612804 6.620843487326056e-05 6.734990165568888e-05 0.007017531897872686 0.002942389575764537
+0.019001029431819916 7.924575038487092e-05 0.00022314759553410113 0.005850819870829582 0.02077627368271351
+0.01629646308720112 3.916713103535585e-05 0.00015153002459555864 0.0038115931674838066 0.011399862356483936
+0.011754895560443401 0.00030378130031749606 4.477824768400751e-05 0.003170839510858059 0.0033414303325116634
+0.013561958447098732 0.0001668809709371999 3.738820305443369e-05 0.0027535047847777605 0.00994561892002821
+0.012989018112421036 8.529614569852129e-05 0.00021250557620078325 0.006900355219841003 0.003592507913708687
+0.01223605964332819 0.000308485294226557 0.00036080312565900385 0.0039718272164464 0.005852557718753815
+0.01775854080915451 7.711213402217254e-05 0.00015797697415109724 0.002733111148700118 0.008053692057728767
+0.02016916498541832 0.0005273947026580572 0.001072207116521895 0.006452658213675022 0.013160026632249355
+0.010807123966515064 0.0002698661119211465 0.00012952432734891772 0.0068076676689088345 0.007041512988507748
+0.010388963855803013 0.0010818311711773276 0.0008843166287988424 0.0051401713863015175 0.0006630902062170208
+0.01686418615281582 9.483266330789775e-05 0.00013869290705770254 0.005074081011116505 0.012973994947969913
+0.014894019812345505 0.0014755706069990993 0.0012632202124223113 0.007641762029379606 0.009880967438220978
+0.020870480686426163 4.815947249880992e-05 0.00017330065020360053 0.008799046277999878 0.023899441584944725
+0.00990554690361023 0.0010184741113334894 0.001492480281740427 0.004969757981598377 0.004027870949357748
+0.009723886847496033 7.015917071839795e-05 7.315967377508059e-05 0.006158214993774891 0.006049174815416336
+0.014164149761199951 0.0009151125559583306 0.0006941808969713748 0.005239034071564674 0.01213429681956768
+0.01455259881913662 3.0872157367412e-05 3.560883487807587e-05 0.004121691454201937 0.015911029651761055
+0.01941620372235775 0.00036033589276485145 0.00015798005915712565 0.00772978737950325 0.010889350436627865
+0.017638642340898514 4.4145337597001344e-05 8.372816228074953e-05 0.003722644643858075 0.01473008282482624
+0.012992626056075096 0.00024721171939745545 0.001057649846188724 0.004011649638414383 0.00542614096775651
+0.014552300795912743 4.7408804675797e-05 0.00027313936152495444 0.00648560281842947 0.0042943162843585014
+0.014746079221367836 0.0007510069408454001 0.0004433584399521351 0.012243976816534996 0.00931878574192524
+0.01502697542309761 0.00017308224050793797 0.0003658925124909729 0.009225914254784584 0.008758081123232841
+0.01976431906223297 0.0003706314309965819 5.8647219702834263e-05 0.006259664427489042 0.018800700083374977
+0.014025500044226646 0.000575582031160593 7.899362390162423e-05 0.0066166408360004425 0.01216903142631054
+0.01628868095576763 2.18118784687249e-05 0.0010606232099235058 0.008630369789898396 0.0073370179161429405
+0.019927091896533966 0.00010001207556342706 0.0014023327967152 0.0054413750767707825 0.027315480634570122
+0.014499695971608162 0.0007810775423422456 0.00024990353267639875 0.0036913780495524406 0.01224703062325716
+0.01614724099636078 0.0011109167244285345 0.002290289383381605 0.0068528251722455025 0.013186915777623653
+0.017427537590265274 0.0011539972620084882 4.397275552037172e-05 0.007877272553741932 0.011533062905073166
+0.013245152309536934 0.002559900050982833 9.184196824207902e-05 0.0069576152600348 0.004721296485513449
+0.015828583389520645 5.069025792181492e-05 0.002170364838093519 0.006310736294835806 0.013504241593182087
+0.012225317768752575 0.00041423970833420753 0.0022301427088677883 0.00886896438896656 0.0053353942930698395
+0.013788407668471336 0.0005390412989072502 0.0005020991084165871 0.008017589338123798 0.008878226391971111
+0.023719720542430878 0.0005541156278923154 0.0018882376607507467 0.008883709087967873 0.012582325376570225
+0.012870747596025467 0.0009337146766483784 0.00016792559472378343 0.011144918389618397 0.0027647402603179216
+0.016569245606660843 0.0013771166559308767 0.00018185999942943454 0.008722329512238503 0.008310463279485703
+0.018252341076731682 2.4830083930282854e-05 0.0018598409369587898 0.005789684131741524 0.005793415009975433
+0.014383010566234589 2.871494689316023e-05 0.0006955844582989812 0.006038574036210775 0.016423694789409637
+0.018044494092464447 0.0003196643083356321 0.0006011391524225473 0.00892604049295187 0.015032966621220112
+0.01182723231613636 6.614907033508644e-05 9.978918387787417e-05 0.004040194675326347 0.005153767764568329
+0.016177473589777946 0.0005841452511958778 4.350004383013584e-05 0.0054006376303732395 0.006685717962682247
+0.011132433079183102 5.469071038533002e-05 0.0014004616532474756 0.006543498486280441 0.004673032555729151
+0.0130226518958807 2.8754633603966795e-05 0.0010350147495046258 0.006597145926207304 0.0033981925807893276
+0.011889561079442501 0.0008003864786587656 0.00018990319222211838 0.007337396033108234 0.00255597080104053
+0.013540128245949745 6.550138641614467e-05 5.3344785555964336e-05 0.007036406081169844 0.004116587340831757
+0.018121609464287758 0.0009601506171748042 0.0005069084581919014 0.006527972873300314 0.020944397896528244
+0.011856143362820148 0.00013116392074152827 6.924616900505498e-05 0.010991514660418034 0.005137532949447632
+0.019152579829096794 0.0006544197676703334 0.0023584088776260614 0.014035315252840519 0.016450293362140656
+0.020731501281261444 0.00015396211529150605 0.0001232298818649724 0.005736753344535828 0.01677858456969261
+0.01877027191221714 0.002163655124604702 0.0017569385236129165 0.006275924388319254 0.013645832426846027
+0.018932390958070755 0.00010179710079682991 5.279696051729843e-05 0.006843919400125742 0.016863010823726654
+0.01688094437122345 0.0026958005037158728 0.001296986360102892 0.006961464416235685 0.0107340719550848
+0.013848239555954933 0.00014171698421705514 0.00011814804747700691 0.007009667344391346 0.014075195416808128
+0.009801733307540417 0.0009172675781883299 0.0014877601061016321 0.009389026090502739 0.0016845195787027478
+0.011349127627909184 5.519697515410371e-05 8.251934923464432e-05 0.006077031139284372 0.006414714269340038
+0.016522390767931938 0.00034386219340376556 0.0009398212423548102 0.007354403380304575 0.016019204631447792
+0.014221500605344772 7.82731658546254e-05 7.892674329923466e-05 0.008026868104934692 0.008099972270429134
+0.014753343537449837 0.0003090432728640735 8.334611629834399e-05 0.004138716962188482 0.004037454724311829
+0.020858407020568848 3.653004387160763e-05 0.00012729637091979384 0.004065114539116621 0.028423972427845
+0.010923734866082668 3.902191383531317e-05 0.00010197958181379363 0.0057983011938631535 0.0051287091337144375
+0.01681331917643547 9.422667790204287e-05 8.278578025056049e-05 0.007408346980810165 0.01999904401600361
+0.016433143988251686 0.00035105968709103763 0.0001742029417073354 0.005491045769304037 0.015000925399363041
+0.014071093872189522 7.818539597792551e-05 0.00019615005294326693 0.004145327024161816 0.00887729600071907
+0.011377604678273201 0.0005282111233100295 0.001039327820762992 0.0048258062452077866 0.0038392688147723675
+0.011647242121398449 9.845782187767327e-05 5.793277523480356e-05 0.007714754901826382 0.003769776551052928
+0.011549705639481544 0.000701222917996347 0.0003731836623046547 0.00631577568128705 0.002841988345608115
+0.015462186187505722 2.6793806682690047e-05 0.00013561974628828466 0.00691928481683135 0.018511241301894188
+0.011166224256157875 0.0002464650315232575 0.0002363503008382395 0.007038934156298637 0.004615339450538158
+0.015752168372273445 8.466943836538121e-05 0.000436605594586581 0.0053772456012666225 0.005233484320342541
+0.014283174648880959 4.829235695069656e-05 0.0005424648988991976 0.0031682869885116816 0.008862112648785114
+0.0106742512434721 0.00044193267240189016 9.054822294274345e-05 0.007979505695402622 0.0030195037834346294
+0.015470359474420547 8.516249363310635e-05 6.447283521993086e-05 0.008104672655463219 0.0067816548980772495
+0.010461715050041676 0.00033981038723140955 0.0009262984967790544 0.0037336237728595734 0.0029817866161465645
+0.015149664133787155 7.188769814092666e-05 5.7479126553516835e-05 0.005761431995779276 0.008923985064029694
+0.010944901965558529 0.0007773854886181653 0.0014497846132144332 0.005549645982682705 0.001454981742426753
+0.018439894542098045 0.0002746151585597545 4.805069329449907e-05 0.002657107776030898 0.01735917665064335
+0.019285039976239204 0.0006458383286371827 4.9463204049970955e-05 0.00312306871637702 0.01200784184038639
+0.018051113933324814 2.53907037404133e-05 0.0007937090122140944 0.009409574791789055 0.00582340732216835
+0.020245419815182686 4.063422602484934e-05 0.00011489192547742277 0.002980391029268503 0.03216443955898285
+0.013437904417514801 5.172221426619217e-05 0.00010680730338208377 0.0035753725096583366 0.010902142152190208
+0.013912789523601532 0.0001753841497702524 0.00020463801047299057 0.009083486162126064 0.006796820089221001
+0.012323262169957161 1.7904236301546916e-05 0.00046504000783897936 0.009107107296586037 0.004906741436570883
+0.01429906114935875 4.69117185275536e-05 0.0010981026571244001 0.003465981688350439 0.0010332900565117598
+0.014474160969257355 0.0007429446559399366 0.00025496335001662374 0.008019891567528248 0.00706095527857542
+0.019761376082897186 0.00036635593278333545 0.0003772324416786432 0.011295578442513943 0.006205372046679258
+0.012838255614042282 0.0011211353121325374 0.0010581259848549962 0.005802812986075878 0.01191666815429926
+0.010751621797680855 0.0006649712449871004 0.000547655567061156 0.005029540508985519 0.0038196027744561434
+0.016454923897981644 0.0006480514421127737 0.000818316824734211 0.004753564018756151 0.017219632863998413
+0.019167087972164154 0.0005286310915835202 0.0006137574091553688 0.005665641278028488 0.019786396995186806
+0.012261731550097466 0.0003884017060045153 0.00043851000373251736 0.00813040230423212 0.007334974128752947
+0.01949935033917427 0.00017967737221624702 0.00015588801761623472 0.0061919474974274635 0.025411654263734818
+0.01602705381810665 0.0005283365026116371 0.001013711211271584 0.004163022618740797 0.015766238793730736
+0.011434526182711124 0.00015424570301547647 0.0007123872055672109 0.0038597588427364826 0.0008746665553189814
+0.01295670960098505 0.0007266786997206509 0.00018314174667466432 0.0031909365206956863 0.00039654690772295
+0.016877111047506332 0.0002629524387884885 5.478499224409461e-05 0.005866874475032091 0.010545887053012848
+0.01611052267253399 0.00013169269368518144 0.0005580927827395499 0.01443550270050764 0.013919580727815628
+0.013505905866622925 8.426369458902627e-05 0.0003469263610895723 0.009950094856321812 0.008398184552788734
+0.014993129298090935 0.00010354380356147885 0.0011141585418954492 0.0066697606816887856 0.013097375631332397
+0.012236982583999634 0.0006726087885908782 7.870075933169574e-05 0.009147007018327713 0.004326438996940851
+0.015635846182703972 0.000570311734918505 0.0001279758580494672 0.009006389416754246 0.007398031651973724
+0.010061381384730339 6.104294152464718e-05 0.0007438011234626174 0.006106733810156584 0.0005554739036597311
+0.017728744074702263 3.216339246137068e-05 0.00011266603542026132 0.0052039045840501785 0.020369017496705055
+0.016858205199241638 0.0001487913978053257 0.0005453330813907087 0.004516024142503738 0.014443742111325264
+0.014416927471756935 7.661730342078954e-05 0.00017812574515119195 0.0027808009181171656 0.0121587123721838
+0.011714356020092964 0.00046778866089880466 4.2140269215451553e-05 0.0036209383979439735 0.0065322150476276875
+0.01021270640194416 3.922883843188174e-05 0.0005887005827389657 0.0052740746177732944 0.00013145842240191996
+0.013257509097456932 4.3280477257212624e-05 0.00046009282232262194 0.007170716300606728 0.006354278419166803
+0.01786699891090393 0.00018048497440759093 0.0001505599939264357 0.005147232208400965 0.009563854895532131
+0.01847323216497898 6.0918777307961136e-05 0.00010644431313266978 0.006269496865570545 0.014061964116990566
+0.01565277948975563 0.0003637733170762658 0.0002938594261649996 0.008855553343892097 0.016190029680728912
+0.013433591462671757 8.849333971738815e-05 0.0005598318530246615 0.01202963013201952 0.010899180546402931
+0.015906650573015213 0.00048035543295554817 0.001970307668671012 0.007219299208372831 0.009323092177510262
+0.010179584845900536 0.00043101495248265564 0.00011060458200518042 0.0029999632388353348 0.004883110523223877
+0.017919175326824188 0.0015639523044228554 0.0005363891832530499 0.004245437681674957 0.014178147539496422
+0.01243971474468708 0.00010485988605068997 0.00016961380606517196 0.004885329864919186 0.0038007821422070265
+0.01370522566139698 0.0007931434665806592 0.00045598027645610273 0.004708249121904373 0.006266572047024965
+0.01750570721924305 8.042690751608461e-05 0.0009303479455411434 0.0037784900050610304 0.004959472920745611
+0.010619400069117546 0.0002513649233151227 0.0021131953690201044 0.002207736251875758 0.00046262898831628263
+0.024328352883458138 0.0005454221391119063 0.0002053114294540137 0.0050182766281068325 0.01852322183549404
+0.01570533588528633 0.0013758995337411761 0.0004477874026633799 0.007496607024222612 0.014137424528598785
+0.009991088882088661 3.953351915697567e-05 0.000389408873161301 0.010560674592852592 0.004685159772634506
+0.011006608605384827 0.0003388413751963526 0.00021402201673481613 0.005710902623832226 0.008260917849838734
+0.02206142619252205 0.00046481547178700566 0.0006096327560953796 0.0042419699020683765 0.03531516343355179
+0.012901533395051956 0.00042983642197214067 0.0005187882343307137 0.00738375261425972 0.010082019492983818
+0.014075339771807194 4.661046841647476e-05 0.0006650086143054068 0.004309803247451782 0.007789428345859051
+0.020592905580997467 4.774153421749361e-05 7.714280218351632e-05 0.002448658924549818 0.02302936278283596
+0.01414693333208561 0.00019993535534013063 0.0012250145664438605 0.004505510907620192 0.0025010844692587852
+0.01803237572312355 0.0005493867793120444 0.0003488028305582702 0.005654498934745789 0.02052556537091732
+0.012474440038204193 0.0020462521351873875 0.00012096659338567406 0.005676133092492819 0.003793303621932864
+0.015944477170705795 7.54542343202047e-05 0.0004851376870647073 0.00628148065879941 0.01567160151898861
+0.0211184099316597 0.0006788162281736732 0.00140468031167984 0.006276588421314955 0.03006233647465706
+0.021900122985243797 2.6170342607656494e-05 5.311066342983395e-05 0.011222673580050468 0.017179353162646294
+0.017276566475629807 0.0006005032337270677 0.0022259652614593506 0.007828719913959503 0.012718587182462215
+0.01361753698438406 0.00019170658197253942 0.0003943425545003265 0.006163164507597685 0.00014651002129539847
+0.013730017468333244 0.0022328600753098726 0.0004205198201816529 0.00548466807231307 0.0053575714118778706
+0.014800680801272392 8.765207894612104e-05 6.42795639578253e-05 0.006460511591285467 0.01188313402235508
+0.017616109922528267 0.001419393694959581 0.0014015065971761942 0.005631537176668644 0.022964082658290863
+0.012311751022934914 0.00024289345310535282 7.934199675219133e-05 0.0077488296665251255 0.008575621992349625
+0.011836658231914043 0.0004330370866227895 0.0008382465457543731 0.0039695873856544495 0.007951981388032436
+0.01934032328426838 0.0002529093180783093 0.000652247981633991 0.004942505154758692 0.011546083725988865
+0.010788623243570328 0.0001451102434657514 0.0001304155302932486 0.005937054753303528 0.0038295993581414223
+0.02619912475347519 0.0003518722951412201 0.00016021390911191702 0.02042582258582115 0.02102934755384922
+0.013992026448249817 8.855002670316026e-05 0.00023022503592073917 0.005261963699012995 0.01400725357234478
+0.018462909385561943 6.835504609625787e-05 0.00035916053457185626 0.00513625331223011 0.009550312533974648
+0.017221778631210327 3.0978604627307504e-05 0.00023198856797534972 0.009578151628375053 0.014341045171022415
+0.01385616697371006 0.0001922614173963666 0.001209461479447782 0.00628281943500042 0.002042471431195736
+0.01597641408443451 0.0001053742817020975 0.0006148045649752021 0.004156249575316906 0.011030839756131172
+0.011568473652005196 0.0013933397131040692 9.128588862949982e-05 0.004970266483724117 0.002990493318066001
+0.023110058158636093 0.000339700112817809 5.1337665354367346e-05 0.005530609283596277 0.0261213518679142
+0.012825394980609417 0.00039911523344926536 0.00041571431211195886 0.007602420169860125 0.0029035620391368866
+0.019752152264118195 0.0002459780953358859 3.2913059840211645e-05 0.004287668038159609 0.011101488023996353
+0.014148172922432423 0.00029430375434458256 0.0022368726786226034 0.004797694738954306 0.008321243338286877
+0.02276691235601902 6.904853216838092e-05 0.0010693040676414967 0.006577655207365751 0.013796856626868248
+0.00900082103908062 0.0020744148641824722 0.0004009941185358912 0.005855425260961056 0.000494154985062778
+0.012376341968774796 0.00027193292044103146 0.00011394050670787692 0.007973385974764824 0.010312797501683235
+0.013197094202041626 0.0012962616747245193 0.0007292803493328393 0.009469354525208473 0.009344031102955341
+0.011331879533827305 0.0003739345702342689 9.597904136171564e-05 0.004988856613636017 0.0007743852911517024
+0.017134910449385643 0.00024199558538384736 0.002458900911733508 0.005422762595117092 0.01837109588086605
diff --git a/model.pth b/model.pth
index 4cd979c..c4b12cd 100644
--- a/model.pth
+++ b/model.pth
Binary files differ
diff --git a/test-2d-histogram.png b/test-2d-histogram.png
deleted file mode 100644
index 75f2f6e..0000000
--- a/test-2d-histogram.png
+++ /dev/null
Binary files differ
diff --git a/test.dot b/test.dot
new file mode 100644
index 0000000..664b30e
--- /dev/null
+++ b/test.dot
@@ -0,0 +1,18 @@
+graph G {
+ 1 [color=blue];
+ 2 [color=red];
+ 3 [color=red];
+ 5 [color=green];
+ 11 [color=red];
+ 1 -- 3 [color=red];
+ 4 -- 3;
+ 11 -- 2 [color=red];
+ 9 -- 3;
+ 6 -- 7
+ 4 -- 5;
+ 10 -- 8;
+ 5 -- 3 [color=red];
+ 3 -- 8;
+ 5 -- 11 [color=red];
+ 4 -- 9;
+}
diff --git a/train.dot b/train.dot
new file mode 100644
index 0000000..48d7d46
--- /dev/null
+++ b/train.dot
@@ -0,0 +1,18 @@
+graph G {
+ 1 [color=blue];
+ 2 [color=green];
+ 3 [color=red];
+ 5 [color=red];
+ 11 [color=red];
+ 1 -- 3 [color=red];
+ 4 -- 3;
+ 11 -- 2 [color=red];
+ 9 -- 3;
+ 6 -- 7
+ 4 -- 5;
+ 10 -- 8;
+ 5 -- 3 [color=red];
+ 3 -- 8;
+ 5 -- 11 [color=red];
+ 4 -- 9;
+}
diff --git a/training-2d-histogram.png b/training-2d-histogram.png
deleted file mode 100644
index 31c3340..0000000
--- a/training-2d-histogram.png
+++ /dev/null
Binary files differ
diff --git a/training-loss.png b/training-loss.png
deleted file mode 100644
index 9cdc225..0000000
--- a/training-loss.png
+++ /dev/null
Binary files differ
diff --git a/transformer_shortest_paths.ipynb b/transformer_shortest_paths.ipynb
index c2164ac..8a99b4b 100644
--- a/transformer_shortest_paths.ipynb
+++ b/transformer_shortest_paths.ipynb
@@ -230,7 +230,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"execution_state": "idle",
"metadata": {},
"outputs": [
@@ -270,7 +270,7 @@
" False]]))"
]
},
- "execution_count": 4,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -281,7 +281,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
"execution_state": "idle",
"metadata": {},
"outputs": [
@@ -308,7 +308,7 @@
" <BarContainer object of 64 artists>)"
]
},
- "execution_count": 5,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
@@ -329,7 +329,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"execution_state": "idle",
"metadata": {},
"outputs": [
@@ -356,7 +356,7 @@
" <BarContainer object of 64 artists>)"
]
},
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
@@ -386,7 +386,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 8,
"execution_state": "idle",
"metadata": {
"id": "tLOWhg_CeWzH"
@@ -423,7 +423,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"execution_state": "idle",
"metadata": {
"colab": {
@@ -438,7 +438,7 @@
"output_type": "stream",
"text": [
"Training data: 1048576K\n",
- "Trainable parameters in the model: 500449\n"
+ "Trainable parameters in the model: 550433\n"
]
}
],
@@ -449,7 +449,7 @@
"NEPOCHS = 1000\n",
"BSZ = 2**15 # Batch size\n",
"NHEADS = 2\n",
- "NLAYERS = 10\n",
+ "NLAYERS = 11\n",
"DROPOUT = 0 # 0.2\n",
"model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n",
" output_dim=1, num_heads=NHEADS,\n",
@@ -474,7 +474,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 10,
"execution_state": "idle",
"metadata": {},
"outputs": [],
@@ -486,7 +486,7 @@
"# model = torch.compile(model)\n",
"# model.load_state_dict(torch.load('model.pth', weights_only=True))\n",
"\n",
- "LR = 5e-4\n",
+ "LR = 3e-4\n",
"\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LR)"
@@ -504,6 +504,7 @@
{
"cell_type": "code",
"execution_count": null,
+ "execution_state": "running",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -512,7 +513,3576 @@
"id": "pvTfzGmCeXU4",
"outputId": "0d3a20f3-23be-4c19-9eb6-46bfe11a48b1"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/sipb/.venv/lib64/python3.12/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)\n",
+ " attn_output = scaled_dot_product_attention(\n",
+ "/home/sipb/.venv/lib64/python3.12/site-packages/torch/_inductor/compile_fx.py:167: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
+ " warnings.warn(\n",
+ "/tmp/torchinductor_sipb/lc/clcqc3ufbzrethiy77dmsu54kurxdmh4eji2f3msm347rhmfpf4j.py:1078: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)\n",
+ " extern_kernels.mm(reinterpret_tensor(buf1, (524288, 64), (64, 1), 0), reinterpret_tensor(primals_5, (64, 192), (1, 64), 0), out=buf2)\n",
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:53<00:00, 1.68s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 0/1000 \t Train Err: 44.25287628173828 15.68527889251709 8.788555145263672 3.8603785037994385 100.71075439453125\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/1000 \t Train Err: 42.749107360839844 19.990596771240234 12.096658706665039 6.1431379318237305 90.65858459472656\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 2/1000 \t Train Err: 35.17692565917969 0.05676012113690376 15.442191123962402 8.584199905395508 82.2193374633789\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 3/1000 \t Train Err: 32.94102478027344 0.0016151332529261708 21.24653434753418 13.028000831604004 70.40452575683594\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4/1000 \t Train Err: 30.707416534423828 0.0012646751711145043 27.01055335998535 17.616626739501953 60.88502502441406\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 5/1000 \t Train Err: 29.5562801361084 0.0010397899895906448 32.37065124511719 21.991931915283203 53.443416595458984\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 6/1000 \t Train Err: 29.092647552490234 0.0007909027044661343 36.97315979003906 25.812442779541016 47.87839126586914\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 7/1000 \t Train Err: 28.649932861328125 0.0005968677578493953 40.76370620727539 28.99480438232422 43.76168441772461\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 8/1000 \t Train Err: 28.49782943725586 0.00046271493192762136 43.50001907348633 31.30975914001465 41.015228271484375\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 9/1000 \t Train Err: 28.43338394165039 0.00036097457632422447 45.3854866027832 32.91228103637695 39.220703125\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 10/1000 \t Train Err: 28.3504581451416 0.0002947713655885309 46.44082260131836 33.81328201293945 38.23492431640625\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 11/1000 \t Train Err: 28.3637638092041 0.00025288635515607893 44.81169509887695 32.499359130859375 39.33035659790039\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 12/1000 \t Train Err: 20.932743072509766 0.020399289205670357 9.31131649017334 9.292236328125 46.657718658447266\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 13/1000 \t Train Err: 15.104679107666016 0.004516741260886192 3.496262311935425 26.43575096130371 27.609149932861328\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 14/1000 \t Train Err: 11.633755683898926 0.006426772102713585 0.33657196164131165 12.253266334533691 24.989166259765625\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 15/1000 \t Train Err: 8.730537414550781 0.005749554373323917 0.4862631857395172 14.309078216552734 15.44126033782959\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 16/1000 \t Train Err: 7.443472385406494 0.006111220922321081 0.20445628464221954 5.455012321472168 16.42494773864746\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 17/1000 \t Train Err: 5.684250354766846 0.0027684608940035105 0.17002306878566742 7.258749961853027 9.979669570922852\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 18/1000 \t Train Err: 4.853455543518066 0.002522163325920701 0.1455077826976776 5.698103427886963 8.302152633666992\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 19/1000 \t Train Err: 4.262304306030273 0.008678397163748741 0.1446814388036728 6.152136325836182 6.787021636962891\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 20/1000 \t Train Err: 3.858119487762451 0.003093192121013999 0.06649590283632278 3.3348793983459473 7.0024285316467285\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 21/1000 \t Train Err: 3.4118576049804688 0.0032999212853610516 0.08424855768680573 4.119073390960693 5.404160976409912\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 22/1000 \t Train Err: 3.108879566192627 0.0038715973496437073 0.09217671304941177 3.321317434310913 4.763245582580566\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 23/1000 \t Train Err: 2.9745163917541504 0.0034261371474713087 0.042425476014614105 2.484790802001953 5.148287296295166\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 24/1000 \t Train Err: 2.8194165229797363 0.0035489683505147696 0.04434378817677498 2.200214385986328 4.857565879821777\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 25/1000 \t Train Err: 2.570242404937744 0.0033989306539297104 0.05114512890577316 1.687565565109253 4.485219478607178\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 26/1000 \t Train Err: 2.519033432006836 0.0032220594584941864 0.037310708314180374 1.374830961227417 4.5875935554504395\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 27/1000 \t Train Err: 2.4283571243286133 0.0033829077146947384 0.05773899331688881 1.9254313707351685 3.3634419441223145\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 28/1000 \t Train Err: 2.302886962890625 0.0026804266963154078 0.03176456317305565 1.1421442031860352 4.208995819091797\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 29/1000 \t Train Err: 2.1564438343048096 0.0029864327516406775 0.05358732491731644 1.2556946277618408 3.505455255508423\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 30/1000 \t Train Err: 2.1266884803771973 0.0027816162910312414 0.0484829805791378 1.4444752931594849 2.9733316898345947\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 31/1000 \t Train Err: 2.0390195846557617 0.0022390384692698717 0.03866223618388176 1.1827622652053833 2.866259813308716\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 32/1000 \t Train Err: 1.9310557842254639 0.007157730869948864 0.04426664113998413 1.3611301183700562 2.6592557430267334\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 33/1000 \t Train Err: 1.851065754890442 0.00836748257279396 0.02673536352813244 0.9250126481056213 2.746307373046875\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 34/1000 \t Train Err: 1.8189555406570435 0.002403518883511424 0.027730008587241173 0.7969512939453125 3.080704689025879\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 35/1000 \t Train Err: 1.7624272108078003 0.0018843680154532194 0.0399935357272625 1.2561802864074707 2.1936941146850586\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 36/1000 \t Train Err: 1.6249182224273682 0.004734295420348644 0.031796831637620926 0.9225327372550964 2.0132899284362793\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 37/1000 \t Train Err: 1.567882776260376 0.0021626888774335384 0.023232363164424896 0.6002436876296997 2.547755718231201\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 38/1000 \t Train Err: 1.5313199758529663 0.00217650830745697 0.026524055749177933 0.7098141312599182 2.2419848442077637\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 39/1000 \t Train Err: 1.4404828548431396 0.004386692773550749 0.02224084362387657 0.4259363114833832 2.4182753562927246\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 40/1000 \t Train Err: 1.2733204364776611 0.0018442481523379683 0.018874434754252434 0.42629194259643555 2.229574680328369\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 41/1000 \t Train Err: 1.259386658668518 0.003313510213047266 0.029421737417578697 0.5856508612632751 1.8896589279174805\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 42/1000 \t Train Err: 1.2521576881408691 0.004783857148140669 0.01938166655600071 0.3461814820766449 2.3111233711242676\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 43/1000 \t Train Err: 1.2292133569717407 0.002071063034236431 0.021942850202322006 0.4656696617603302 2.068932056427002\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 44/1000 \t Train Err: 1.128746747970581 0.003524728585034609 0.02287750504910946 0.45642775297164917 1.7306324243545532\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 45/1000 \t Train Err: 1.0967061519622803 0.01454948354512453 0.024551132693886757 0.54127436876297 1.5558075904846191\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 46/1000 \t Train Err: 1.0431199073791504 0.012847564183175564 0.04218384996056557 0.6190078258514404 1.189369559288025\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 47/1000 \t Train Err: 1.0133185386657715 0.0021092321258038282 0.013825979083776474 0.2770753800868988 1.6761137247085571\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 48/1000 \t Train Err: 0.9567054510116577 0.004586610943078995 0.014754511415958405 0.3456564247608185 1.453568696975708\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 49/1000 \t Train Err: 0.9407138228416443 0.0070776138454675674 0.02054545097053051 0.4561208486557007 1.220456838607788\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 50/1000 \t Train Err: 0.9202176332473755 0.0031614142935723066 0.011641060933470726 0.286685585975647 1.5036773681640625\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 51/1000 \t Train Err: 0.8502955436706543 0.004485825542360544 0.014037835411727428 0.261470228433609 1.4015151262283325\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 52/1000 \t Train Err: 0.8055517673492432 0.001586323487572372 0.01207085233181715 0.25895756483078003 1.2997738122940063\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 53/1000 \t Train Err: 0.745665431022644 0.0011599217541515827 0.011456611566245556 0.3580838143825531 1.1415109634399414\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 54/1000 \t Train Err: 0.7061874866485596 0.0031232465989887714 0.011377043090760708 0.3388000428676605 0.9392197728157043\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 55/1000 \t Train Err: 0.6648885607719421 0.0012688999995589256 0.00947241485118866 0.26360955834388733 1.1422845125198364\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 56/1000 \t Train Err: 0.7076109051704407 0.0024704698007553816 0.009398349560797215 0.3038281798362732 1.1086375713348389\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 57/1000 \t Train Err: 0.6785869598388672 0.0014807500410825014 0.01062911655753851 0.2191060185432434 1.1564222574234009\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 58/1000 \t Train Err: 0.6877794861793518 0.0014096886152401567 0.009400570765137672 0.30001187324523926 1.0206444263458252\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 59/1000 \t Train Err: 0.6211211681365967 0.0020437357015907764 0.007910181768238544 0.25007426738739014 0.9663841724395752\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 60/1000 \t Train Err: 0.6470574736595154 0.002864431357011199 0.009593289345502853 0.2109789252281189 1.1501246690750122\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 61/1000 \t Train Err: 0.5603711605072021 0.0008608726784586906 0.007078849244862795 0.20478710532188416 0.8139731884002686\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 62/1000 \t Train Err: 0.578612208366394 0.0032247111666947603 0.00902671180665493 0.42988908290863037 0.7487800717353821\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 63/1000 \t Train Err: 0.5186704397201538 0.00102539814542979 0.008104916661977768 0.2222120761871338 0.7046389579772949\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 64/1000 \t Train Err: 0.5231799483299255 0.002828507451340556 0.010938749648630619 0.2966611087322235 0.7127745747566223\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 65/1000 \t Train Err: 0.47884172201156616 0.0009621918434277177 0.006440795958042145 0.18400059640407562 0.704879105091095\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 66/1000 \t Train Err: 0.5025489926338196 0.0009218254708684981 0.010015169158577919 0.31187739968299866 0.6326962113380432\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 67/1000 \t Train Err: 0.43556588888168335 0.0009745100978761911 0.004804776981472969 0.18548405170440674 0.7211376428604126\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 68/1000 \t Train Err: 0.4796810746192932 0.0011300166370347142 0.00498929712921381 0.19455471634864807 0.7402910590171814\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 69/1000 \t Train Err: 0.4702203869819641 0.0010465772356837988 0.00821524765342474 0.19802263379096985 0.6487798094749451\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 70/1000 \t Train Err: 0.4697670042514801 0.0023691540118306875 0.01255676057189703 0.12256406992673874 0.7824207544326782\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 71/1000 \t Train Err: 0.45800718665122986 0.000834575854241848 0.006560661364346743 0.31839612126350403 0.5303618311882019\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 72/1000 \t Train Err: 0.4298030138015747 0.0008242715266533196 0.004054277669638395 0.12216580659151077 0.7027520537376404\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 73/1000 \t Train Err: 0.41886794567108154 0.003492131596431136 0.006118918769061565 0.13266974687576294 0.5819666981697083\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 74/1000 \t Train Err: 0.4155438542366028 0.005384758580476046 0.004661516286432743 0.18202976882457733 0.46242862939834595\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 75/1000 \t Train Err: 0.4029538929462433 0.0027740243822336197 0.004722914192825556 0.11864388734102249 0.6423448920249939\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 76/1000 \t Train Err: 0.40730276703834534 0.0007228117901831865 0.004857086110860109 0.19555920362472534 0.5367923974990845\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 77/1000 \t Train Err: 0.36605364084243774 0.0009539870661683381 0.004075151868164539 0.07661056518554688 0.6930822730064392\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 78/1000 \t Train Err: 0.35346490144729614 0.0008490128675475717 0.011079245246946812 0.2623752951622009 0.35006481409072876\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 79/1000 \t Train Err: 0.36204081773757935 0.0009433329687453806 0.00526107894256711 0.21243207156658173 0.48520198464393616\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 80/1000 \t Train Err: 0.37033170461654663 0.007113362662494183 0.025272028520703316 0.1579497754573822 0.509062647819519\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 81/1000 \t Train Err: 0.33435767889022827 0.0053968350403010845 0.009334199130535126 0.08819041401147842 0.5340674519538879\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 82/1000 \t Train Err: 0.33757588267326355 0.0005324012017808855 0.004930454771965742 0.17805498838424683 0.5068127512931824\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 83/1000 \t Train Err: 0.31310147047042847 0.0008425716659985483 0.002387900371104479 0.0860978364944458 0.5014434456825256\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 84/1000 \t Train Err: 0.2939777672290802 0.0005211451789364219 0.0028972462750971317 0.09632651507854462 0.32632964849472046\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 85/1000 \t Train Err: 0.299526184797287 0.0010479650227352977 0.0028079829644411802 0.21169798076152802 0.40237608551979065\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 86/1000 \t Train Err: 0.30821743607521057 0.0029116582591086626 0.0035962064284831285 0.12347913533449173 0.4878503084182739\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 87/1000 \t Train Err: 0.30215075612068176 0.0008221607422456145 0.002995165530592203 0.14785777032375336 0.37114524841308594\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 88/1000 \t Train Err: 0.28064608573913574 0.0016746899345889688 0.0033167945221066475 0.12754225730895996 0.36404597759246826\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 89/1000 \t Train Err: 0.3496437072753906 0.00890024658292532 0.004831824917346239 0.15963110327720642 0.61163729429245\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 90/1000 \t Train Err: 0.25926288962364197 0.0013453153660520911 0.003530274610966444 0.1070982962846756 0.3172129690647125\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 91/1000 \t Train Err: 0.27578499913215637 0.0030208230018615723 0.0031245634891092777 0.11023195832967758 0.3903992176055908\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 92/1000 \t Train Err: 0.2994685769081116 0.004197238944470882 0.0023698117583990097 0.1903829723596573 0.4574669599533081\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 93/1000 \t Train Err: 0.2922815680503845 0.0007029320695437491 0.002348515437915921 0.05019014701247215 0.5161943435668945\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 94/1000 \t Train Err: 0.2519854009151459 0.0017584029119461775 0.002146892249584198 0.09622428566217422 0.3795848488807678\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 95/1000 \t Train Err: 0.2682843506336212 0.0005638344446197152 0.002339219441637397 0.10693421214818954 0.41891878843307495\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 96/1000 \t Train Err: 0.23551344871520996 0.0013522340450435877 0.002122987760230899 0.07678966224193573 0.39242544770240784\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 97/1000 \t Train Err: 0.22959834337234497 0.0011262426851317286 0.005589182022958994 0.09626860916614532 0.3728175759315491\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 98/1000 \t Train Err: 0.25520408153533936 0.00047998130321502686 0.0021106365602463484 0.06091078370809555 0.46852487325668335\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 99/1000 \t Train Err: 0.2501460909843445 0.0004842414637096226 0.0019760928116738796 0.08570502698421478 0.2884354889392853\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 100/1000 \t Train Err: 0.20366686582565308 0.002971487818285823 0.00713342847302556 0.11135696619749069 0.31001028418540955\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 101/1000 \t Train Err: 0.24212954938411713 0.002736428752541542 0.006694257725030184 0.07518018037080765 0.4043455421924591\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 102/1000 \t Train Err: 0.21833863854408264 0.0023800276685506105 0.0043370905332267284 0.06609099358320236 0.32194095849990845\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 103/1000 \t Train Err: 0.17334935069084167 0.0016206795116886497 0.002039699349552393 0.07832205295562744 0.2227916121482849\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 104/1000 \t Train Err: 0.22861380875110626 0.00032782970811240375 0.0014689437812194228 0.04350364953279495 0.3887456953525543\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 105/1000 \t Train Err: 0.19524142146110535 0.0035212738439440727 0.006021957378834486 0.11795622110366821 0.25324639678001404\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 106/1000 \t Train Err: 0.15101337432861328 0.0010637118248268962 0.006495238747447729 0.07132977992296219 0.20919014513492584\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 107/1000 \t Train Err: 0.16016799211502075 0.0004500410286709666 0.0017898066435009241 0.05250254645943642 0.23433028161525726\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 108/1000 \t Train Err: 0.16674524545669556 0.0010276080574840307 0.004126560874283314 0.11649245023727417 0.1926332712173462\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 109/1000 \t Train Err: 0.1631726324558258 0.000415267248172313 0.0019610195886343718 0.047292474657297134 0.2021452933549881\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 110/1000 \t Train Err: 0.15775498747825623 0.000414436828577891 0.0015739103546366096 0.09467942267656326 0.19835522770881653\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 111/1000 \t Train Err: 0.1562364548444748 0.0061783152632415295 0.0036767187993973494 0.05223577097058296 0.2691049575805664\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 112/1000 \t Train Err: 0.12982504069805145 0.006277484819293022 0.0025301864370703697 0.0412493497133255 0.1602799892425537\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 113/1000 \t Train Err: 0.16289404034614563 0.0005659068119712174 0.0012484622420743108 0.056075569242239 0.21527709066867828\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 114/1000 \t Train Err: 0.1465548872947693 0.0014426918933168054 0.0027563830371946096 0.05711502209305763 0.18701669573783875\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 115/1000 \t Train Err: 0.14275527000427246 0.000653824012260884 0.0014178442070260644 0.0954805463552475 0.21085317432880402\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 116/1000 \t Train Err: 0.15016044676303864 0.000818660540971905 0.0016999151557683945 0.06580580770969391 0.1766045093536377\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 117/1000 \t Train Err: 0.12201909720897675 0.0015263869427144527 0.0018288405844941735 0.041804831475019455 0.1663890928030014\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 118/1000 \t Train Err: 0.12489122897386551 0.00029311463003978133 0.001244824263267219 0.025852475315332413 0.14584743976593018\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 119/1000 \t Train Err: 0.13460129499435425 0.010341204702854156 0.006242326460778713 0.04511499032378197 0.2327692061662674\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 120/1000 \t Train Err: 0.1339903473854065 0.0011800267966464162 0.0022739972919225693 0.032896753400564194 0.15254607796669006\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 121/1000 \t Train Err: 0.12022703886032104 0.00038554586353711784 0.004875303246080875 0.06586042791604996 0.1974797397851944\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 122/1000 \t Train Err: 0.11403584480285645 0.0008288199896924198 0.0015265692491084337 0.02858065813779831 0.17847180366516113\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 123/1000 \t Train Err: 0.10409620404243469 0.0009630229906179011 0.0013286847388371825 0.07061317563056946 0.12211479991674423\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 124/1000 \t Train Err: 0.09401391446590424 0.0007407785742543638 0.001438300940208137 0.04962582886219025 0.11180052161216736\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 125/1000 \t Train Err: 0.11572305113077164 0.0022304926533252 0.0015191843267530203 0.044924236834049225 0.16141481697559357\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 126/1000 \t Train Err: 0.09036736190319061 0.0019903378561139107 0.000955000170506537 0.041499681770801544 0.07043104618787766\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 127/1000 \t Train Err: 0.10687348991632462 0.0012028561905026436 0.0007425333024002612 0.03638176992535591 0.15492156147956848\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 128/1000 \t Train Err: 0.10416512191295624 0.00024147499061655253 0.0009293883922509849 0.028597401455044746 0.18022169172763824\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 129/1000 \t Train Err: 0.08835020661354065 0.003051575506106019 0.003711810801178217 0.05737682431936264 0.120809406042099\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 130/1000 \t Train Err: 0.09325292706489563 0.0029582607094198465 0.0012874709209427238 0.048724617809057236 0.09893109649419785\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 131/1000 \t Train Err: 0.0896429568529129 0.00045463963760994375 0.0013198753586038947 0.035679738968610764 0.10701773315668106\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 132/1000 \t Train Err: 0.08471657335758209 0.0019452492706477642 0.0034233005717396736 0.0296705961227417 0.06480914354324341\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 133/1000 \t Train Err: 0.09226851165294647 0.0015877934638410807 0.002174847759306431 0.02515079453587532 0.13485927879810333\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 134/1000 \t Train Err: 0.08039496093988419 0.0059851668775081635 0.003604266792535782 0.023677635937929153 0.10068632662296295\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 135/1000 \t Train Err: 0.07723499834537506 0.0029561370611190796 0.0013383673503994942 0.02225509285926819 0.07217926532030106\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 136/1000 \t Train Err: 0.08774624764919281 0.0015881237341091037 0.0009627199615351856 0.03728866204619408 0.1386537104845047\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 137/1000 \t Train Err: 0.09112915396690369 0.0011313623981550336 0.000519853550940752 0.02461346611380577 0.09449820965528488\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 138/1000 \t Train Err: 0.0764939934015274 0.0015739547088742256 0.002696078270673752 0.02539546601474285 0.07497072219848633\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 139/1000 \t Train Err: 0.08085883408784866 0.0021111860405653715 0.002203426556661725 0.038074932992458344 0.09940771758556366\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 140/1000 \t Train Err: 0.06707267463207245 0.00048505759332329035 0.0005829626461490989 0.023159010335803032 0.06615001708269119\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 141/1000 \t Train Err: 0.06301475316286087 0.0009192335419356823 0.0005843854160048068 0.03602571412920952 0.0841449722647667\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 142/1000 \t Train Err: 0.06407696008682251 0.004187325481325388 0.0031896685250103474 0.025413284078240395 0.0635743960738182\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 143/1000 \t Train Err: 0.056899648159742355 0.0012612127466127276 0.0010306687327101827 0.021679643541574478 0.05377352237701416\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 144/1000 \t Train Err: 0.07650666683912277 0.00024983700131997466 0.001161872292868793 0.03564540296792984 0.08232379704713821\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 145/1000 \t Train Err: 0.05720144882798195 0.0014120059786364436 0.0017668631626293063 0.03595374524593353 0.06960506737232208\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 146/1000 \t Train Err: 0.0696098804473877 0.00025453948182985187 0.0005893263150937855 0.018764346837997437 0.03399818018078804\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 147/1000 \t Train Err: 0.06046349182724953 0.000990899046882987 0.0004648214380722493 0.02092691697180271 0.02779306285083294\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 148/1000 \t Train Err: 0.059290580451488495 0.0022199612576514482 0.001315723406150937 0.023515846580266953 0.07406013458967209\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 149/1000 \t Train Err: 0.06473296880722046 0.0018222469370812178 0.001621031784452498 0.03560686111450195 0.06342668831348419\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 150/1000 \t Train Err: 0.053641099482774734 0.0003684388066176325 0.0008896152721717954 0.026580102741718292 0.07826311886310577\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 151/1000 \t Train Err: 0.05356581509113312 7.923251541797072e-05 0.00026825847453437746 0.018474940210580826 0.06338636577129364\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 152/1000 \t Train Err: 0.053509317338466644 0.000861925829667598 0.0006878876592963934 0.03851037099957466 0.045843761414289474\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 153/1000 \t Train Err: 0.053101420402526855 9.4662478659302e-05 0.0003207987465430051 0.020021196454763412 0.06933460384607315\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 154/1000 \t Train Err: 0.04212021827697754 0.00016945773677434772 0.0002675392315723002 0.01669382117688656 0.040745120495557785\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 155/1000 \t Train Err: 0.038299448788166046 0.0006350624025799334 0.0006876391125842929 0.021383486688137054 0.03851837292313576\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 156/1000 \t Train Err: 0.05666494742035866 0.0003512321272864938 0.00041459291242063046 0.02086237445473671 0.0636894553899765\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 157/1000 \t Train Err: 0.03654496371746063 7.953925523906946e-05 0.00024021716671995819 0.016982462257146835 0.03953758254647255\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 158/1000 \t Train Err: 0.04525972530245781 0.00030174560379236937 0.00044638325925916433 0.011691144667565823 0.041540782898664474\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 159/1000 \t Train Err: 0.039281442761421204 0.0004989005392417312 0.0007457975880242884 0.02650555409491062 0.03530228137969971\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 160/1000 \t Train Err: 0.04518686234951019 0.00017189154459629208 0.00015555098070763052 0.01069786585867405 0.0570724681019783\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 161/1000 \t Train Err: 0.050255514681339264 0.00024805148132145405 0.0007079236092977226 0.02650279738008976 0.06579962372779846\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 162/1000 \t Train Err: 0.044423386454582214 0.00031116121681407094 0.0003777859383262694 0.017297720536589622 0.051457729190588\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 163/1000 \t Train Err: 0.04354090988636017 0.0005877528456039727 0.00029637134866788983 0.013659253716468811 0.020231008529663086\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 164/1000 \t Train Err: 0.031157420948147774 9.798325481824577e-05 0.0003091402177233249 0.01797928847372532 0.023115914314985275\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 165/1000 \t Train Err: 0.04593996703624725 8.030498429434374e-05 0.0002926443121396005 0.013717164285480976 0.055002935230731964\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 166/1000 \t Train Err: 0.03411879763007164 0.0003377409011591226 0.0002743960649240762 0.01241423562169075 0.01924932189285755\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 167/1000 \t Train Err: 0.04760969430208206 0.00011350985732860863 0.00043987829121761024 0.014360723085701466 0.02648620493710041\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 168/1000 \t Train Err: 0.03372378274798393 0.0002949015761259943 0.00037480168975889683 0.014883959665894508 0.016884008422493935\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 169/1000 \t Train Err: 0.03702085092663765 0.00018087019270751625 0.0006288738804869354 0.024845065549016 0.04383993148803711\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 170/1000 \t Train Err: 0.02824767678976059 0.0003424994938541204 0.0004191694315522909 0.015994146466255188 0.014700128696858883\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 171/1000 \t Train Err: 0.03430527076125145 0.0009072513785213232 0.0003185459354426712 0.015698766335844994 0.011771176941692829\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 172/1000 \t Train Err: 0.03365296125411987 0.00030303309904411435 0.00022830757370684296 0.017944660037755966 0.02212398685514927\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 173/1000 \t Train Err: 0.041653022170066833 0.000587280432227999 0.0007374075939878821 0.015334094874560833 0.06751907616853714\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 174/1000 \t Train Err: 0.03477595001459122 0.0009183208458125591 0.00013558704813476652 0.013238264247775078 0.01334806066006422\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 175/1000 \t Train Err: 0.037530943751335144 0.00025187639403156936 0.00033183913910761476 0.01504728477448225 0.031140420585870743\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 176/1000 \t Train Err: 0.027848348021507263 0.0014607060002163053 0.001157138729467988 0.011170606128871441 0.01426300685852766\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 177/1000 \t Train Err: 0.03245190531015396 0.0002554329694248736 0.00015877540863584727 0.014040789566934109 0.04714728519320488\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 178/1000 \t Train Err: 0.041977256536483765 0.0001379920431645587 0.00021775254572276026 0.009346511214971542 0.01231997087597847\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 179/1000 \t Train Err: 0.03204528987407684 9.879021672531962e-05 0.0003249702858738601 0.01157579105347395 0.032357051968574524\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 180/1000 \t Train Err: 0.030494939535856247 5.318508192431182e-05 0.0001941232185345143 0.01810402423143387 0.03970681130886078\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 181/1000 \t Train Err: 0.03218014910817146 0.0006771694752387702 0.00016017410962376744 0.009343582205474377 0.03883388638496399\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 182/1000 \t Train Err: 0.03439633548259735 8.35641985759139e-05 0.0008245277567766607 0.026584787294268608 0.04398134723305702\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 183/1000 \t Train Err: 0.03437873348593712 0.0003609458508435637 0.0001779649028321728 0.011353005655109882 0.022845609113574028\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 184/1000 \t Train Err: 0.029747625812888145 0.0005436437204480171 0.0006577487220056355 0.02241024561226368 0.03428546339273453\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 185/1000 \t Train Err: 0.03702589124441147 0.001318416208960116 0.0003100436006207019 0.006715381983667612 0.012282946147024632\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 186/1000 \t Train Err: 0.022873839363455772 0.00014723198546562344 0.0011230326490476727 0.015240584500133991 0.020318059250712395\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 187/1000 \t Train Err: 0.02688605710864067 0.00014236278366297483 0.00022871489636600018 0.017673302441835403 0.029775310307741165\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 188/1000 \t Train Err: 0.026118462905287743 3.561102130333893e-05 0.00015905339387245476 0.010189777240157127 0.01750301755964756\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 189/1000 \t Train Err: 0.031032036989927292 6.626216054428369e-05 0.0001567322324262932 0.014112681150436401 0.03739028424024582\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 190/1000 \t Train Err: 0.02181524783372879 0.00023006339324638247 0.0005933582433499396 0.010394968092441559 0.00304942368529737\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 191/1000 \t Train Err: 0.027184132486581802 0.004482210148125887 0.0019434703281149268 0.010643635876476765 0.018308935686945915\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 192/1000 \t Train Err: 0.024762475863099098 0.00021026897593401372 0.00068919628392905 0.013836676254868507 0.00759515818208456\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 193/1000 \t Train Err: 0.027485240250825882 5.092989158583805e-05 0.00041440658969804645 0.007257797755300999 0.029338931664824486\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 194/1000 \t Train Err: 0.022330544888973236 0.0004199454269837588 0.0004707501211669296 0.019549088552594185 0.0023543694987893105\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 195/1000 \t Train Err: 0.022276023402810097 6.339305400615558e-05 0.00018617554451338947 0.012746947817504406 0.02517220377922058\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 196/1000 \t Train Err: 0.02131063863635063 6.503217446152121e-05 0.0001230579655384645 0.009538630954921246 0.010168644599616528\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 197/1000 \t Train Err: 0.021132078021764755 0.0002956095850095153 8.20207133074291e-05 0.006673957221210003 0.010186844505369663\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 198/1000 \t Train Err: 0.019411560148000717 0.00042201197356916964 0.0009095012792386115 0.013239831663668156 0.018353110179305077\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 199/1000 \t Train Err: 0.03216658905148506 4.9469184887129813e-05 0.0002943709841929376 0.01148073747754097 0.016176927834749222\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 200/1000 \t Train Err: 0.02571428008377552 0.0010263347066938877 0.0005481779226101935 0.016015449538826942 0.01978922262787819\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 201/1000 \t Train Err: 0.026272011920809746 3.759583705686964e-05 9.925611811922863e-05 0.008123734965920448 0.01331140287220478\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 202/1000 \t Train Err: 0.03202397748827934 6.046362250344828e-05 0.000284630514215678 0.011763068847358227 0.04103183373808861\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 203/1000 \t Train Err: 0.012851575389504433 0.00011363952944520861 0.00016413710545748472 0.010319913737475872 0.002792573068290949\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 204/1000 \t Train Err: 0.026115499436855316 0.0005245269276201725 0.00024940649745985866 0.011106887832283974 0.02185993455350399\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 205/1000 \t Train Err: 0.018948335200548172 2.8755090170307085e-05 8.58057028381154e-05 0.005748868454247713 0.015843190252780914\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 206/1000 \t Train Err: 0.018574167042970657 0.00015149133105296642 0.00048761311336420476 0.02577385865151882 0.013325332663953304\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 207/1000 \t Train Err: 0.02082146145403385 0.000238466338487342 0.00023394331219606102 0.006748242769390345 0.019078979268670082\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 208/1000 \t Train Err: 0.014403259381651878 6.994893192313612e-05 0.0008789921994321048 0.010097990743815899 0.00996581930667162\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 209/1000 \t Train Err: 0.020921753719449043 0.0004979136865586042 0.00010895056766457856 0.00956189725548029 0.015378139913082123\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 210/1000 \t Train Err: 0.022038539871573448 0.0001340815215371549 0.00015701379743404686 0.014402204193174839 0.02424498274922371\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 211/1000 \t Train Err: 0.01667097397148609 0.00011134350643260404 0.00014891373575665057 0.00600818358361721 0.009433449245989323\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 212/1000 \t Train Err: 0.01602860540151596 4.298926069168374e-05 7.710252975812182e-05 0.004729713778942823 0.005112841725349426\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 213/1000 \t Train Err: 0.01903412863612175 0.00030714491731487215 0.0001821738842409104 0.005955686792731285 0.022314537316560745\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 214/1000 \t Train Err: 0.015248223207890987 0.00026072614127770066 0.00019891293777618557 0.007477042265236378 0.0036678414326161146\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 215/1000 \t Train Err: 0.014912809245288372 0.00047885856474749744 0.0004725187609437853 0.00442067626863718 0.006714401766657829\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 216/1000 \t Train Err: 0.022600673139095306 5.488095121108927e-05 0.0005237676086835563 0.0061714984476566315 0.03166600689291954\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 217/1000 \t Train Err: 0.022118201479315758 8.072196214925498e-05 0.00010541417577769607 0.00999145582318306 0.022547846660017967\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 218/1000 \t Train Err: 0.01601581647992134 3.067262514377944e-05 0.00018664993694983423 0.004181408789008856 0.013367431238293648\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 219/1000 \t Train Err: 0.01162786316126585 5.145326576894149e-05 0.00015623572107870132 0.006794607732445002 0.001922846888191998\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 220/1000 \t Train Err: 0.012896351516246796 0.0003044742625206709 9.53769194893539e-05 0.006998015101999044 0.003260016907006502\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 221/1000 \t Train Err: 0.023554792627692223 0.00018602493219077587 0.0004451847053132951 0.011047394014894962 0.03432558849453926\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 222/1000 \t Train Err: 0.016859270632267 0.0002703650388866663 0.00010159210069105029 0.0054770237766206264 0.013049724511802197\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 223/1000 \t Train Err: 0.018101878464221954 5.2419105486478657e-05 6.96199931553565e-05 0.005247439723461866 0.011564591899514198\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 224/1000 \t Train Err: 0.015453542582690716 0.00010983451647916809 0.00022485408408101648 0.005099967587739229 0.012966718524694443\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 225/1000 \t Train Err: 0.018901851028203964 4.1673782106954604e-05 0.0015293046599254012 0.00927029736340046 0.015386526472866535\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 226/1000 \t Train Err: 0.019163722172379494 0.0016058672918006778 0.002710553817451 0.003990732133388519 0.013572349213063717\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 227/1000 \t Train Err: 0.0185165386646986 8.810655708657578e-05 0.0001792146940715611 0.007222854066640139 0.020297806710004807\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 228/1000 \t Train Err: 0.020369376987218857 0.0009141123155131936 0.001427489914931357 0.008912032470107079 0.02442525327205658\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 229/1000 \t Train Err: 0.01552458293735981 0.00011080846888944507 6.443277379730716e-05 0.004490815103054047 0.00922235008329153\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 230/1000 \t Train Err: 0.0195521991699934 2.673462404345628e-05 0.0002968674525618553 0.0058441669680178165 0.013992007821798325\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 231/1000 \t Train Err: 0.01986689493060112 4.5369306462816894e-05 4.316281047067605e-05 0.005999550223350525 0.015800679102540016\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 232/1000 \t Train Err: 0.019927091896533966 0.00010001207556342706 0.0014023327967152 0.0054413750767707825 0.027315480634570122\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 233/1000 \t Train Err: 0.014753343537449837 0.0003090432728640735 8.334611629834399e-05 0.004138716962188482 0.004037454724311829\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 234/1000 \t Train Err: 0.012261731550097466 0.0003884017060045153 0.00043851000373251736 0.00813040230423212 0.007334974128752947\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 235/1000 \t Train Err: 0.011006608605384827 0.0003388413751963526 0.00021402201673481613 0.005710902623832226 0.008260917849838734\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 236/1000 \t Train Err: 0.00900082103908062 0.0020744148641824722 0.0004009941185358912 0.005855425260961056 0.000494154985062778\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 6%|████████████▌ | 2/32 [00:01<00:30, 1.00s/it]"
+ ]
+ }
+ ],
"source": [
"while epoch < NEPOCHS:\n",
" model.train()\n",
@@ -527,18 +4097,17 @@
" optimizer.zero_grad()\n",
" output = model(batch_src, batch_padding_mask)\n",
" loss = criterion(output.squeeze(1), batch_labels)\n",
- " train_loss = loss.item()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
- " train_err.append(train_loss)\n",
+ " train_err.append(loss.item())\n",
" len1.append(criterion(output[batch_labels == 1].squeeze(1), batch_labels[batch_labels == 1]).item())\n",
" len2.append(criterion(output[batch_labels == 2].squeeze(1), batch_labels[batch_labels == 2]).item())\n",
" len3.append(criterion(output[batch_labels == 3].squeeze(1), batch_labels[batch_labels == 3]).item())\n",
" leninf.append(criterion(output[batch_labels == MAX_VTXS].squeeze(1), batch_labels[batch_labels == MAX_VTXS]).item())\n",
" \n",
" with open('loss', 'a') as f:\n",
- " f.write(f\"{train_loss} {len1[-1]} {len2[-1]} {len3[-1]} {leninf[-1]}\\n\")\n",
+ " f.write(f\"{train_err[-1]} {len1[-1]} {len2[-1]} {len3[-1]} {leninf[-1]}\\n\")\n",
" print(f\"Epoch {epoch}/{NEPOCHS} \\t Train Err: {train_err[-1]} {len1[-1]} {len2[-1]} {len3[-1]} {leninf[-1]}\")\n",
"\n",
" epoch += 1\n",
@@ -839,7 +4408,7 @@
" print(f\"Epoch {epoch}/{N_TUNE_EPOCHS} \\t Train Err: {train_loss:.4f}\")\n",
"\n",
" if epoch % 10 == 9:\n",
- " torch.save(model.state_dict(), f\"tune_model_weights_{epoch + 1}.pth\")"
+ " torch.save(model.state_dict(), f\"checkpoints/tune_model_weights_{epoch + 1}.pth\")"
]
},
{