Tensorflow入门

win10 安装Tensorflow

安装tensorflow的时候一定要看一下版本支持情况,我准备使用tensorflow2.1,根据readme我们显卡驱动要用CUDA 10.1 and cuDNN 7.6.

安装 anaconda

安装nvidia cuda

CUDA是NVIDIA推出的运算平台,CuDNN是专门针对Deep Learning框架设计的一套GPU计算加速方案。在安装之前要查询下最新TensorFLow发行版支持到了哪个版本。

查看是否安装成功,打印版本号
C:\Users\14784>nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Fri_Feb__8_19:08:26_Pacific_Standard_Time_2019
Cuda compilation tools, release 10.1, V10.1.105

也可以通过Nvidia控制面板查看是否安装成功

65

安装nvidia cudnn

下载地址:https://developer.nvidia.com/rdp/cudnn-download

  • 我下载的是cudnn-10.1-windows10-x64-v7.6.5.32.zip,与cuda一致

  • 解压缩有三个文件夹 bin ,include, lib

  • 复制粘贴cuDNN里面的三个文件到CUDA的相应同名文件

  • cuda\bin =>C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\bin

  • cuda\include => C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include

  • cuda\lib\x64 => C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\lib\x64

安装tensorflow

在anaconda里创建名为tensorflow(也可以叫其他名字)的环境

conda create -n tensorflow pip python=3.7

pip python=3.7的意思是在名为tensorflow的环境里搭建版本是3.7的python。
python版本要跟tensorflow指定的一致。
激活conda环境
conda activate tensorflow
我查看了最新的tensorflow是2.1.0,所以安装2.1.0,当然也可以不指定版本
pip install tensorflow-gpu==2.1.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

模型训练

TensorFlow有一个模型集市,上面有各种各样预训练的机器学习模型,用来解决各种各样的共性问题,在这些模型的基础上,使用自己的数据改进模型,这比自己从头训练要高效。

预训练模型地址

参考g3doc/detection_model_zoo

见链接里面有各种预训练模型:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

选择模型

git clone https://github.com/tensorflow/models.git

码云地址
git clone https://gitee.com/yang_zi_0709/tensorflow-models.git
  • official目录:TensorFlow的高级API的示例模型的集合,它们得到良好的维护,支持最新稳定API,经过了充分的测试,并进行过优化,是TensorFlow用户的首选。

  • research目录:研究人员在TensorFlow中实施的模型,它们没有得到官方支持,也不能保证在后续的TensorFlow发布版本中工作,带有一些研究性质。

  • samples目录:包含代码片段和较小的模型,用于演示TensorFlow的功能,包括各种博客文章中提供的代码。

  • tutorials目录:TensorFlow教程中描述的模型集合。

object_detection api安装

编译protobuf

因为我之前在Ubuntu上已经安装了protobuf,所以直接在Ubuntu上编译好,拿到Windows里。

protoc object_detection/protos/*.proto --python_out=.

跑测试代码

python object_detection/builders/model_builder_test.py

66

“ImportError:No module named 'object_detection'”的错误,则说明缺少环境变量。

查看site-packages目录:
python -m site

我的目录是E:\software\anaconda3\envs\tensorflow\Lib\site-packages

在该目录下创建.pth文件,文件名随意,例如我的是:tensorflow_path.pth。
文件中添加下面两行,要根据自己的地址来

D:\github\tensorflow-models\research
D:\github\tensorflow-models\research\slim


重新执行,还是报错slim = tf.contrib.slim module 'tensorflow' has no attribute 'contrib'
是因为tf版本问题
只能用tf v1版本

再创建一个tensorflow1.15.0

conda create -n tf pip python=3.6
conda activate tf
pip install tensorflow-gpu==1.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

同时cuda cudnn 安装10.0版本,但是发现10.0版本cuda死活装不上,放弃不装了,用没有gpu的版本吧

pip uninstall tensorflow-gpu
pip install tensorflow==1.15.0  -i https://pypi.tuna.tsinghua.edu.cn/simple

跑一个目标识别demo

我们切换回tensorflow环境下

在object_detection目录下新建一个object_detect_demo.py名字自己定。

#一定要保存为UTF8的格式哦
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
#import tensorflow as tf
import zipfile
import matplotlib
import cv2

#通过下面两行1.x代码可以转成2.x
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
 
# Matplotlib chooses Xwindows backend by default.
matplotlib.use('Agg')
 
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util
 
 
##################### Download Model,如果本地已下载也可修改成本地路径
# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
 
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
 
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
 
NUM_CLASSES = 90
 
# Download model if not already downloaded
if not os.path.exists(PATH_TO_CKPT):
    print('Downloading model... (This may take over 5 minutes)')
    opener = urllib.request.URLopener()
    opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
    print('Extracting...')
    tar_file = tarfile.open(MODEL_FILE)
    for file in tar_file.getmembers():
        file_name = os.path.basename(file.name)
        if 'frozen_inference_graph.pb' in file_name:
            tar_file.extract(file, os.getcwd())
else:
    print('Model already downloaded.')
 
##################### Load a (frozen) Tensorflow model into memory.
print('Loading model...')
detection_graph = tf.Graph()
 
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
 
##################### Loading label map
print('Loading label map...')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
 
##################### Helper code
def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)
 
##################### Detection
# 测试图片的路径,可以根据自己的实际情况修改
TEST_IMAGE_PATH = 'test_images/image1.jpg'
 
# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)
 
print('Detecting...')
with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    print(TEST_IMAGE_PATH)
    image = Image.open(TEST_IMAGE_PATH)
    image_np = load_image_into_numpy_array(image)
    image_np_expanded = np.expand_dims(image_np, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})
 
    # Visualization of the results of a detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8)
    print(TEST_IMAGE_PATH.split('.')[0]+'_labeled.jpg')
    plt.figure(figsize=IMAGE_SIZE, dpi=300)
    # 不知道为什么,在我的机器上没显示出图片,有知道的朋友指点下,谢谢
    plt.imshow(image_np)
    # 保存标记图片
    plt.savefig(TEST_IMAGE_PATH.split('.')[0] + '_labeled.jpg')

运行

python object_detect_demo.py

可能报错缺少一些库手动安装就行

 pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
 
pip install opencv-python  -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install Pillow -i https://pypi.tuna.tsinghua.edu.cn/simple

运行的时候还可能报其他问题但是用下面两行都能解决

#通过下面两行1.x代码可以转成2.x
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

训练自己目标检测模型

选择预训练模型

我们选择

ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03

在object_detection 目录下新建一个train目录(名字自己定),将上面的模型拷贝进去。

在object_detection/samples/configs下找到模型多对应的.config文件,将其拷贝到自己的目录中,并修改该文件。

将num_classes: 96中的96修改为自己索要训练的类别数目,若训练的类别有2类,则num_classes: 2

将fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt 的路径设置为自己的训练模型所存放的路径,例如 fine_tune_checkpoint: ./model.ckpt

input_path的路径设置为自己的训练数据集的路径,例如input_path: "./new_data_set/all_train_1.record",一共有两处

label_map_path的路径设置刚才创建label_map.pbtxt所在的路径,例如 label_map_path:"label_map.pbtxt"

num_examples: 8000 修改为自己数据集的验证集的大小,例如num_examples: 1000

input_path设置为验证集的路径,例如: input_path"./data_set/all_vaild_1.record"

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×