|
86 | 86 | "!pip install tensorflow==2.15.0 tensorflow-quantum==0.7.3 tensorboard_plugin_profile==2.15.0"
|
87 | 87 | ]
|
88 | 88 | },
|
89 |
| - { |
90 |
| - "cell_type": "code", |
91 |
| - "execution_count": 0, |
92 |
| - "metadata": { |
93 |
| - "colab": {}, |
94 |
| - "colab_type": "code", |
95 |
| - "id": "4Ql5PW-ACO0J" |
96 |
| - }, |
97 |
| - "outputs": [], |
98 |
| - "source": [ |
99 |
| - "# Update package resources to account for version changes.\n", |
100 |
| - "import importlib, pkg_resources\n", |
101 |
| - "importlib.reload(pkg_resources)" |
102 |
| - ] |
103 |
| - }, |
104 |
| - { |
| 89 | + { |
| 90 | + "cell_type": "code", |
| 91 | + "execution_count": 0, |
| 92 | + "metadata": { |
| 93 | + "colab": {}, |
| 94 | + "colab_type": "code", |
| 95 | + "id": "4Ql5PW-ACO0J" |
| 96 | + }, |
| 97 | + "outputs": [], |
| 98 | + "source": [ |
| 99 | + "# Update package resources to account for version changes.\n", |
| 100 | + "import importlib, pkg_resources\n", |
| 101 | + "\n", |
| 102 | + "importlib.reload(pkg_resources)" |
| 103 | + ] |
| 104 | + }, |
| 105 | + { |
105 | 106 | "cell_type": "code",
|
106 | 107 | "execution_count": null,
|
107 | 108 | "metadata": {
|
|
159 | 160 | " qubits, depth=2)\n",
|
160 | 161 | " return random_circuit\n",
|
161 | 162 | "\n",
|
| 163 | + "\n", |
162 | 164 | "def generate_data(circuit, n_samples):\n",
|
163 | 165 | " \"\"\"Draw n_samples samples from circuit into a tf.Tensor.\"\"\"\n",
|
164 |
| - " return tf.squeeze(tfq.layers.Sample()(circuit, repetitions=n_samples).to_tensor())" |
| 166 | + " return tf.squeeze(tfq.layers.Sample()(circuit,\n", |
| 167 | + " repetitions=n_samples).to_tensor())" |
165 | 168 | ]
|
166 | 169 | },
|
167 | 170 | {
|
|
270 | 273 | " \"\"\"Convert tensor of bitstrings to tensor of ints.\"\"\"\n",
|
271 | 274 | " sigs = tf.constant([1 << i for i in range(N_QUBITS)], dtype=tf.int32)\n",
|
272 | 275 | " rounded_bits = tf.clip_by_value(tf.math.round(\n",
|
273 |
| - " tf.cast(bits, dtype=tf.dtypes.float32)), clip_value_min=0, clip_value_max=1)\n", |
274 |
| - " return tf.einsum('jk,k->j', tf.cast(rounded_bits, dtype=tf.dtypes.int32), sigs)\n", |
| 276 | + " tf.cast(bits, dtype=tf.dtypes.float32)),\n", |
| 277 | + " clip_value_min=0,\n", |
| 278 | + " clip_value_max=1)\n", |
| 279 | + " return tf.einsum('jk,k->j', tf.cast(rounded_bits, dtype=tf.dtypes.int32),\n", |
| 280 | + " sigs)\n", |
| 281 | + "\n", |
275 | 282 | "\n",
|
276 | 283 | "@tf.function\n",
|
277 | 284 | "def xeb_fid(bits):\n",
|
278 | 285 | " \"\"\"Compute linear XEB fidelity of bitstrings.\"\"\"\n",
|
279 | 286 | " final_probs = tf.squeeze(\n",
|
280 |
| - " tf.abs(tfq.layers.State()(REFERENCE_CIRCUIT).to_tensor()) ** 2)\n", |
| 287 | + " tf.abs(tfq.layers.State()(REFERENCE_CIRCUIT).to_tensor())**2)\n", |
281 | 288 | " nums = bits_to_ints(bits)\n",
|
282 |
| - " return (2 ** N_QUBITS) * tf.reduce_mean(tf.gather(final_probs, nums)) - 1.0" |
| 289 | + " return (2**N_QUBITS) * tf.reduce_mean(tf.gather(final_probs, nums)) - 1.0" |
283 | 290 | ]
|
284 | 291 | },
|
285 | 292 | {
|
|
334 | 341 | "outputs": [],
|
335 | 342 | "source": [
|
336 | 343 | "LATENT_DIM = 100\n",
|
| 344 | + "\n", |
| 345 | + "\n", |
337 | 346 | "def make_generator_model():\n",
|
338 | 347 | " \"\"\"Construct generator model.\"\"\"\n",
|
339 | 348 | " model = tf.keras.Sequential()\n",
|
|
345 | 354 | "\n",
|
346 | 355 | " return model\n",
|
347 | 356 | "\n",
|
| 357 | + "\n", |
348 | 358 | "def make_discriminator_model():\n",
|
349 | 359 | " \"\"\"Constrcut discriminator model.\"\"\"\n",
|
350 | 360 | " model = tf.keras.Sequential()\n",
|
|
387 | 397 | "outputs": [],
|
388 | 398 | "source": [
|
389 | 399 | "cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
|
| 400 | + "\n", |
| 401 | + "\n", |
390 | 402 | "def discriminator_loss(real_output, fake_output):\n",
|
391 | 403 | " \"\"\"Compute discriminator loss.\"\"\"\n",
|
392 | 404 | " real_loss = cross_entropy(tf.ones_like(real_output), real_output)\n",
|
393 | 405 | " fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)\n",
|
394 | 406 | " total_loss = real_loss + fake_loss\n",
|
395 | 407 | " return total_loss\n",
|
396 | 408 | "\n",
|
| 409 | + "\n", |
397 | 410 | "def generator_loss(fake_output):\n",
|
398 | 411 | " \"\"\"Compute generator loss.\"\"\"\n",
|
399 | 412 | " return cross_entropy(tf.ones_like(fake_output), fake_output)\n",
|
400 | 413 | "\n",
|
| 414 | + "\n", |
401 | 415 | "generator_optimizer = tf.keras.optimizers.Adam(1e-4)\n",
|
402 | 416 | "discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)"
|
403 | 417 | ]
|
|
410 | 424 | },
|
411 | 425 | "outputs": [],
|
412 | 426 | "source": [
|
413 |
| - "BATCH_SIZE=256\n", |
| 427 | + "BATCH_SIZE = 256\n", |
| 428 | + "\n", |
414 | 429 | "\n",
|
415 | 430 | "@tf.function\n",
|
416 | 431 | "def train_step(images):\n",
|
|
425 | 440 | " gen_loss = generator_loss(fake_output)\n",
|
426 | 441 | " disc_loss = discriminator_loss(real_output, fake_output)\n",
|
427 | 442 | "\n",
|
428 |
| - " gradients_of_generator = gen_tape.gradient(\n", |
429 |
| - " gen_loss, generator.trainable_variables)\n", |
| 443 | + " gradients_of_generator = gen_tape.gradient(gen_loss,\n", |
| 444 | + " generator.trainable_variables)\n", |
430 | 445 | " gradients_of_discriminator = disc_tape.gradient(\n",
|
431 | 446 | " disc_loss, discriminator.trainable_variables)\n",
|
432 | 447 | "\n",
|
|
480 | 495 | "def train(dataset, epochs, start_epoch=1):\n",
|
481 | 496 | " \"\"\"Launch full training run for the given number of epochs.\"\"\"\n",
|
482 | 497 | " # Log original training distribution.\n",
|
483 |
| - " tf.summary.histogram('Training Distribution', data=bits_to_ints(dataset), step=0)\n", |
| 498 | + " tf.summary.histogram('Training Distribution',\n", |
| 499 | + " data=bits_to_ints(dataset),\n", |
| 500 | + " step=0)\n", |
484 | 501 | "\n",
|
485 |
| - " batched_data = tf.data.Dataset.from_tensor_slices(dataset).shuffle(N_SAMPLES).batch(512)\n", |
| 502 | + " batched_data = tf.data.Dataset.from_tensor_slices(dataset).shuffle(\n", |
| 503 | + " N_SAMPLES).batch(512)\n", |
486 | 504 | " t = time.time()\n",
|
487 | 505 | " for epoch in range(start_epoch, start_epoch + epochs):\n",
|
488 | 506 | " for i, image_batch in enumerate(batched_data):\n",
|
489 | 507 | " # Log batch-wise loss.\n",
|
490 | 508 | " gl, dl = train_step(image_batch)\n",
|
491 |
| - " tf.summary.scalar(\n", |
492 |
| - " 'Generator loss', data=gl, step=epoch * len(batched_data) + i)\n", |
493 |
| - " tf.summary.scalar(\n", |
494 |
| - " 'Discriminator loss', data=dl, step=epoch * len(batched_data) + i)\n", |
| 509 | + " tf.summary.scalar('Generator loss',\n", |
| 510 | + " data=gl,\n", |
| 511 | + " step=epoch * len(batched_data) + i)\n", |
| 512 | + " tf.summary.scalar('Discriminator loss',\n", |
| 513 | + " data=dl,\n", |
| 514 | + " step=epoch * len(batched_data) + i)\n", |
495 | 515 | "\n",
|
496 | 516 | " # Log full dataset XEB Fidelity and generated distribution.\n",
|
497 | 517 | " generated_samples = generator(tf.random.normal([N_SAMPLES, 100]))\n",
|
498 |
| - " tf.summary.scalar(\n", |
499 |
| - " 'Generator XEB Fidelity Estimate', data=xeb_fid(generated_samples), step=epoch)\n", |
500 |
| - " tf.summary.histogram(\n", |
501 |
| - " 'Generator distribution', data=bits_to_ints(generated_samples), step=epoch)\n", |
| 518 | + " tf.summary.scalar('Generator XEB Fidelity Estimate',\n", |
| 519 | + " data=xeb_fid(generated_samples),\n", |
| 520 | + " step=epoch)\n", |
| 521 | + " tf.summary.histogram('Generator distribution',\n", |
| 522 | + " data=bits_to_ints(generated_samples),\n", |
| 523 | + " step=epoch)\n", |
502 | 524 | " # Log new samples drawn from this particular random circuit.\n",
|
503 | 525 | " random_new_distribution = generate_data(REFERENCE_CIRCUIT, N_SAMPLES)\n",
|
504 |
| - " tf.summary.histogram(\n", |
505 |
| - " 'New round of True samples', data=bits_to_ints(random_new_distribution), step=epoch)\n", |
| 526 | + " tf.summary.histogram('New round of True samples',\n", |
| 527 | + " data=bits_to_ints(random_new_distribution),\n", |
| 528 | + " step=epoch)\n", |
506 | 529 | "\n",
|
507 | 530 | " if epoch % 10 == 0:\n",
|
508 | 531 | " print('Epoch {}, took {}(s)'.format(epoch, time.time() - t))\n",
|
|
0 commit comments