Question
First off, I am using TFX version 0.21.2 and Tensorflow version 2.1.
I have constructed a pipeline largely following the Chicago taxi example. When the Trainer component is executed, I can see the following in the logs:

INFO - Training complete. Model written to /root/airflow/tfx/pipelines/fish/Trainer/Model/9/serving_model_dir

When checking the above directory, it is empty. What am I missing?
This is my DAG definition file (import statements omitted):
_pipeline_name = 'fish'
_airflow_config = AirflowPipelineConfig(airflow_dag_config={
    'schedule_interval': None,
    'start_date': datetime.datetime(2019, 1, 1),
})
_project_root = os.path.join(os.environ['HOME'], 'airflow')
_data_root = os.path.join(_project_root, 'data', 'fish_data')
_module_file = os.path.join(_project_root, 'dags', 'fishUtils.py')
_serving_model_dir = os.path.join(_project_root, 'serving_model', _pipeline_name)
_tfx_root = os.path.join(_project_root, 'tfx')
_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
                              'metadata.db')


def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:

    examples = external_input(data_root)
    example_gen = CsvExampleGen(input=examples)

    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    infer_schema = SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=False)

    validate_stats = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=infer_schema.outputs['schema'])

    trainer = Trainer(
        examples=example_gen.outputs['examples'],
        schema=infer_schema.outputs['schema'],
        module_file=_module_file,
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))

    model_validator = ModelValidator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'])

    pusher = Pusher(
        model=trainer.outputs['model'],
        model_blessing=model_validator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=_serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=_pipeline_name,
        pipeline_root=_pipeline_root,
        components=[
            example_gen,
            statistics_gen,
            infer_schema,
            validate_stats,
            trainer,
            model_validator,
            pusher],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers]
    )


runner = AirflowDagRunner(config=_airflow_config)
DAG = runner.run(
    _create_pipeline(
        pipeline_name=_pipeline_name,
        pipeline_root=_pipeline_root,
        data_root=_data_root,
        module_file=_module_file,
        serving_model_dir=_serving_model_dir,
        metadata_path=_metadata_path,
        # 0 means auto-detect based on the number of CPUs available during
        # execution time.
        direct_num_workers=0))
And this is my module file:
_DENSE_FLOAT_FEATURE_KEYS = ['length']

real_valued_columns = [tf.feature_column.numeric_column('length')]


def _eval_input_receiver_fn():
    serialized_tf_example = tf.compat.v1.placeholder(
        dtype=tf.string, shape=[None], name='input_example_tensor')

    features = tf.io.parse_example(
        serialized=serialized_tf_example,
        features={
            'length': tf.io.FixedLenFeature([], tf.float32),
            'label': tf.io.FixedLenFeature([], tf.int64),
        })

    receiver_tensors = {'examples': serialized_tf_example}

    return tfma.export.EvalInputReceiver(
        features={'length': features['length']},
        receiver_tensors=receiver_tensors,
        labels=features['label'],
    )


def parser(serialized_example):
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'length': tf.io.FixedLenFeature([], tf.float32),
            'label': tf.io.FixedLenFeature([], tf.int64),
        })
    return ({'length': features['length']}, features['label'])


def _input_fn(filenames):
    # TFRecordDataset doesn't directly accept paths with wildcards
    filenames = tf.data.Dataset.list_files(filenames)
    dataset = tf.data.TFRecordDataset(filenames, 'GZIP')
    dataset = dataset.map(parser)
    dataset = dataset.shuffle(2000)
    dataset = dataset.batch(40)
    dataset = dataset.repeat(10)
    return dataset


def trainer_fn(trainer_fn_args, schema):
    estimator = tf.estimator.LinearClassifier(
        feature_columns=real_valued_columns)

    train_input_fn = lambda: _input_fn(trainer_fn_args.train_files)
    train_spec = tf.estimator.TrainSpec(
        train_input_fn,
        max_steps=trainer_fn_args.train_steps)

    eval_input_fn = lambda: _input_fn(trainer_fn_args.eval_files)
    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn,
        steps=trainer_fn_args.eval_steps,
        name='fish-eval')

    receiver_fn = lambda: _eval_input_receiver_fn()

    return {
        'estimator': estimator,
        'train_spec': train_spec,
        'eval_spec': eval_spec,
        'eval_input_receiver_fn': receiver_fn
    }
Thanks in advance for your help!
Answer
Posting the solution for anyone who runs into the same problem as me.
The reason the model was not written to the file system is that the estimator needs a config argument to know where to write the model.
The following modification to the trainer_fn function should solve the problem:
run_config = tf.estimator.RunConfig(save_checkpoints_steps=999,
                                    keep_checkpoint_max=1)
run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)

estimator = tf.estimator.LinearClassifier(feature_columns=real_valued_columns,
                                          config=run_config)
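For completeness, here is a sketch of the whole trainer_fn with the change in place; everything apart from the RunConfig lines is unchanged from the module file in the question. Without an explicit model_dir the estimator falls back to a temporary default directory, which is why the logged serving_model_dir stayed empty.

def trainer_fn(trainer_fn_args, schema):
    # Point the estimator at the directory TFX reports in its logs;
    # without this, checkpoints go to a temporary default directory.
    run_config = tf.estimator.RunConfig(save_checkpoints_steps=999,
                                        keep_checkpoint_max=1)
    run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)

    estimator = tf.estimator.LinearClassifier(
        feature_columns=real_valued_columns,
        config=run_config)

    train_input_fn = lambda: _input_fn(trainer_fn_args.train_files)
    train_spec = tf.estimator.TrainSpec(
        train_input_fn,
        max_steps=trainer_fn_args.train_steps)

    eval_input_fn = lambda: _input_fn(trainer_fn_args.eval_files)
    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn,
        steps=trainer_fn_args.eval_steps,
        name='fish-eval')

    receiver_fn = lambda: _eval_input_receiver_fn()

    return {
        'estimator': estimator,
        'train_spec': train_spec,
        'eval_spec': eval_spec,
        'eval_input_receiver_fn': receiver_fn
    }

After re-running the pipeline with this change, the serving_model_dir path shown in the Trainer log should no longer be empty.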