weixin_39615741
2021-01-12 12:38

TF 2.0 API for using the embedding projector

Preparing embeddings for projector with tensorflow2.

The TensorFlow 1 code would look something like this:

import tensorflow as tf
from tensorboard.plugins import projector

embeddings = tf.compat.v1.Variable(latent_data, name='embeddings')
CHECKPOINT_FILE = TENSORBOARD_DIR + '/model.ckpt'
# Write summaries for tensorboard
with tf.compat.v1.Session() as sess:
    saver = tf.compat.v1.train.Saver([embeddings])
    sess.run(embeddings.initializer)
    saver.save(sess, CHECKPOINT_FILE)
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = embeddings.name
    embedding.metadata_path = TENSORBOARD_METADATA_FILE

projector.visualize_embeddings(tf.compat.v1.summary.FileWriter(TENSORBOARD_DIR), config)

When using eager mode in TensorFlow 2, this should (presumably) look something like this:

embeddings = tf.Variable(latent_data, name='embeddings')
CHECKPOINT_FILE = TENSORBOARD_DIR + '/model.ckpt'
ckpt = tf.train.Checkpoint(embeddings=embeddings)
ckpt.save(CHECKPOINT_FILE)

config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embeddings.name
embedding.metadata_path = TENSORBOARD_METADATA_FILE

writer = tf.summary.create_file_writer(TENSORBOARD_DIR)
projector.visualize_embeddings(writer, config)

However, there are two issues:

  • the writer created with tf.summary.create_file_writer does not have the get_logdir() method required by projector.visualize_embeddings; a simple workaround is to patch visualize_embeddings to take the logdir as a parameter.
  • the checkpoint format has changed: when reading the checkpoint with load_checkpoint (which seems to be how TensorBoard loads the file), the variable names change, e.g. embeddings becomes something like embeddings/.ATTRIBUTES/VARIABLE_VALUE (there are also additional variables in the map returned by get_variable_to_shape_map(), but they are empty anyway).
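For the first issue, an alternative to patching visualize_embeddings would be to wrap the new writer in a small adapter that supplies the missing method. This is only a sketch; LogdirWriter is a hypothetical helper name, not a TensorBoard API:

```python
class LogdirWriter:
    """Adapter that adds the get_logdir() method expected by the
    TF1-era projector.visualize_embeddings to a TF2 summary writer."""

    def __init__(self, writer, logdir):
        self._writer = writer
        self._logdir = logdir

    def get_logdir(self):
        # The method visualize_embeddings calls to locate the log directory.
        return self._logdir

    def __getattr__(self, name):
        # Delegate everything else to the wrapped summary writer.
        return getattr(self._writer, name)
```

One could then pass LogdirWriter(tf.summary.create_file_writer(TENSORBOARD_DIR), TENSORBOARD_DIR) to visualize_embeddings unchanged, at the cost of relying on the projector only ever calling get_logdir() on the writer.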

The second issue was solved with the following quick-and-dirty workaround (and logdir is now a parameter of visualize_embeddings()):

import tensorflow as tf
from tensorboard.plugins import projector

embeddings = tf.Variable(latent_data, name='embeddings')
CHECKPOINT_FILE = TENSORBOARD_DIR + '/model.ckpt'
ckpt = tf.train.Checkpoint(embeddings=embeddings)
ckpt.save(CHECKPOINT_FILE)

# Look up the mangled name the object-based checkpoint gave the variable
# (e.g. 'embeddings/.ATTRIBUTES/VARIABLE_VALUE').
reader = tf.train.load_checkpoint(TENSORBOARD_DIR)
shape_map = reader.get_variable_to_shape_map()
key_to_use = ""
for key in shape_map:
    if "embeddings" in key:
        key_to_use = key

config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = key_to_use
embedding.metadata_path = TENSORBOARD_METADATA_FILE

writer = tf.summary.create_file_writer(TENSORBOARD_DIR)
projector.visualize_embeddings(writer, config, TENSORBOARD_DIR)

I did not find any examples of how to use TensorFlow 2 to write the embeddings for TensorBoard directly, so I am not sure whether this is the right way; but if it is, then those two issues would need to be addressed.
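If patching visualize_embeddings is undesirable, another option is to skip it entirely and write the projector_config.pbtxt file that the projector plugin reads from the log directory. This is a sketch: the field names follow the ProjectorConfig proto, but the tensor name, directory, and metadata path below are illustrative placeholders.

```python
import os
import tempfile

# Illustrative values; in practice these come from the code above.
TENSORBOARD_DIR = tempfile.mkdtemp()
TENSORBOARD_METADATA_FILE = "metadata.tsv"
key_to_use = "embeddings/.ATTRIBUTES/VARIABLE_VALUE"  # as found via load_checkpoint

# Write the text-format ProjectorConfig proto by hand. The projector
# plugin looks for projector_config.pbtxt in the log directory.
config_text = (
    "embeddings {\n"
    f'  tensor_name: "{key_to_use}"\n'
    f'  metadata_path: "{TENSORBOARD_METADATA_FILE}"\n'
    "}\n"
)
with open(os.path.join(TENSORBOARD_DIR, "projector_config.pbtxt"), "w") as f:
    f.write(config_text)
```

This sidesteps both issues, since neither the writer nor visualize_embeddings is involved, though it hard-codes knowledge of the proto's text format.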

dump of diagnose_tensorboard.py

Diagnostics

--- check: autoidentify
INFO: diagnose_tensorboard.py version 393931f9685bd7e0f3898d7dcdf28819fef54c43
--- check: general
INFO: sys.version_info: sys.version_info(major=3, minor=7, micro=3, releaselevel='final', serial=0)
INFO: os.name: posix
INFO: os.uname(): posix.uname_result(sysname='Darwin', nodename='MBPT', release='18.6.0', version='Darwin Kernel Version 18.6.0: Thu Apr 25 23:16:27 PDT 2019; root:xnu-4903.261.4~2/RELEASE_X86_64', machine='x86_64')
INFO: sys.getwindowsversion(): N/A
--- check: package_management
INFO: has conda-meta: True
INFO: $VIRTUAL_ENV: None
--- check: installed_packages
INFO: installed: tb-nightly==1.14.0a20190603
INFO: installed: tensorflow==2.0.0b1
INFO: installed: tf-estimator-nightly==1.14.0.dev2019060501
--- check: tensorboard_python_version
INFO: tensorboard.version.VERSION: '1.14.0a20190603'
--- check: tensorflow_python_version
INFO: tensorflow.__version__: '2.0.0-beta1'
INFO: tensorflow.__git_version__: 'v2.0.0-beta0-16-g1d91213fe7'
--- check: tensorboard_binary_path
INFO: which tensorboard: b'/USER_DIR/anaconda3/envs/TF20/bin/tensorboard\n'
--- check: readable_fqdn
INFO: socket.getfqdn(): '104.1.168.192.in-addr.arpa'
--- check: stat_tensorboardinfo
INFO: directory: /var/folders/zv/0ywdhk0s55q2770ygg2xbty40000gn/T/.tensorboard-info
INFO: .tensorboard-info directory does not exist
--- check: source_trees_without_genfiles
INFO: tensorboard_roots (1): ['/USER_DIR/anaconda3/envs/TF20/lib/python3.7/site-packages']; bad_roots (0): []
--- check: full_pip_freeze
INFO: pip freeze --all:
absl-py==0.7.1
astor==0.8.0
certifi==2019.6.16
gast==0.2.2
google-pasta==0.1.7
grpcio==1.22.0
h5py==2.9.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
numpy==1.16.4
pandas==0.25.0
pip==19.2.1
protobuf==3.9.0
python-dateutil==2.8.0
pytz==2019.1
setuptools==41.0.1
six==1.12.0
tb-nightly==1.14.0a20190603
tensorflow==2.0.0b1
termcolor==1.1.0
tf-estimator-nightly==1.14.0.dev2019060501
Werkzeug==0.15.5
wheel==0.33.4
wrapt==1.11.2

This question comes from the open-source project: tensorflow/tensorboard


15 replies
