Here is the code:

```java
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.Random;

public class ImagePipelineExample {
    protected static final Logger log = LoggerFactory.getLogger(ImagePipelineExample.class);

    // Images must have one of the extensions listed in allowedExtensions
    protected static final String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;

    protected static final long seed = 12345;
    public static final Random randNumGen = new Random(seed);

    // Every image is resized to height x width with this many channels
    protected static int height = 50;
    protected static int width = 50;
    protected static int channels = 3;

    public static void main(String[] args) throws Exception {
        // DIRECTORY STRUCTURE:
        // Images in the dataset have to be organized in directories by class/label.
        // In this example there are ten images in three classes.
        // Here is the directory structure:
        //                 parentDir
        //               /     |     \
        //              /      |      \
        //         labelA   labelB   labelC
        //
        // Set your data up like this, so that the images of each label/class live in their own directory
        // and these label/class directories live together in the parent directory.

        // The label of each image is the name of its parent directory
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // Declared but not used in this version of the code
        BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);

        // Training data
        File trainDir = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/DataExamples/ImagePipeline/");
        InputSplit trainData1 = new FileSplit(trainDir);
        recordReader.initialize(trainData1);
        int outputNum = recordReader.numLabels();
        // Batch size 10; the label is at index 1 of each record; outputNum possible labels
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, 10, 1, outputNum);
        while (dataIter.hasNext()) {
            DataSet ds = dataIter.next();
            System.out.println("train:" + ds);
        }
        recordReader.reset();
        System.out.println("train Finished!");

        // Test data
        File testDir = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/DataExamples/testlabel/");
        InputSplit testData1 = new FileSplit(testDir);
        recordReader.initialize(testData1);
        // Build the test data iterator
        DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader, 10, 1, outputNum);
        while (testIter.hasNext()) {
            DataSet ds = testIter.next(); // read from testIter, not dataIter
            System.out.println("test:" + ds);
        }
        recordReader.reset();
    }
}
```
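The directory comments in the code rely on each image's label being the name of the folder that directly contains it. The sketch below is a hypothetical plain-Java helper, not DataVec code, that illustrates this mapping; it assumes class indices follow the order in which labels are first seen, which is only an assumption about how the reader orders labels.

```java
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper (not part of DataVec): mimics the idea behind
// ParentPathLabelGenerator, where an image's label is the name of the
// directory that directly contains it.
class ParentDirLabels {

    // Label for one file: the name of its parent directory.
    static String labelFor(String filePath) {
        Path parent = Paths.get(filePath).getParent();
        Path dirName = (parent == null) ? null : parent.getFileName();
        if (dirName == null) {
            throw new IllegalArgumentException("No parent directory in path: " + filePath);
        }
        return dirName.toString();
    }

    // Distinct labels in first-seen order (an assumption); a label's position
    // in this list plays the role of its class index.
    static List<String> labelsFromPaths(List<String> filePaths) {
        Set<String> seen = new LinkedHashSet<>();
        for (String p : filePaths) {
            seen.add(labelFor(p));
        }
        return new ArrayList<>(seen);
    }

    public static void main(String[] args) {
        List<String> train = Arrays.asList(
                "ImagePipeline/labelA/img1.jpg",
                "ImagePipeline/labelB/img2.jpg",
                "ImagePipeline/labelC/img3.jpg");
        System.out.println(labelsFromPaths(train));
    }
}
```

Any folder that holds images therefore becomes a label, including a folder meant only for test images.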
There is a test-set directory (testlabel) in which I put a single picture of a flower, and running the program fails right away:
```
train Finished!
Labels:[labelA, labelB, labelC, testlabel]
Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 3
	at org.nd4j.linalg.util.FeatureUtil.toOutcomeVector(FeatureUtil.java:38)
	at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReaderDataSetIterator.java:234)
	at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:186)
	at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:389)
	at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:52)
	at org.deeplearning4j.examples.dataexamples.ImagePipelineExample.main(ImagePipelineExample.java:91)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)

Process finished with exit code 1
```
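As you can see, it also treats the test-set directory as an output: testlabel is added as a fourth label, but the iterator was created with outputNum = 3 (computed from the training directory), so label index 3 is out of range.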
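The exception is thrown in FeatureUtil.toOutcomeVector, which converts a label index into a one-hot vector. Assuming that vector is sized by the numPossibleLabels argument passed to RecordReaderDataSetIterator, the failure can be reproduced with the hypothetical plain-Java re-creation below; it is not DL4J's own code, only the same failure mode.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical re-creation of the failure mode (not DL4J code): a one-hot
// vector sized for the label count known when the iterator was built
// cannot hold a label index that was discovered later.
class OneHotSketch {

    // One-hot vector of length numOutcomes with 1.0 at position index.
    // Throws ArrayIndexOutOfBoundsException when index >= numOutcomes.
    static double[] toOneHot(int index, int numOutcomes) {
        double[] v = new double[numOutcomes];
        v[index] = 1.0;
        return v;
    }

    public static void main(String[] args) {
        // Labels after re-initializing the reader on the test directory
        List<String> labels = Arrays.asList("labelA", "labelB", "labelC", "testlabel");
        int outputNum = 3;                              // fixed from the training directory
        int flowerIndex = labels.indexOf("testlabel");  // the flower image's label index
        try {
            toOneHot(flowerIndex, outputNum);
        } catch (ArrayIndexOutOfBoundsException e) {
            System.out.println("Out of bounds: index " + flowerIndex + " for " + outputNum + " outputs");
        }
    }
}
```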
How can I get the network's output for a single image?
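A common approach in DL4J, stated here as an assumption rather than a verified recipe for this exact version: load the single file with `new NativeImageLoader(height, width, channels).asMatrix(file)`, apply the same preprocessing used during training, and call `output(...)` on a trained network. Note that the posted code only iterates over the data and never trains a model. The result is a row of probabilities whose positions follow the training label list, so that list should be kept (for example from `recordReader.getLabels()` after initializing on the training directory) instead of re-initializing the reader on a test folder. The JDK-only sketch below covers just the last step, mapping probabilities to a label name; `PredictionSketch` is a hypothetical helper and the probability values are made up.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: turns one row of network output probabilities into a
// label name. Assumes position i in the row corresponds to trainingLabels.get(i).
class PredictionSketch {

    // Index of the largest value (argmax); ties go to the lowest index.
    static int argMax(double[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("empty output");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    static String predictLabel(double[] probabilities, List<String> trainingLabels) {
        // Fail with a clear message if the output size and label count disagree
        if (probabilities.length != trainingLabels.size()) {
            throw new IllegalArgumentException("Output size " + probabilities.length
                    + " != label count " + trainingLabels.size());
        }
        return trainingLabels.get(argMax(probabilities));
    }

    public static void main(String[] args) {
        List<String> trainingLabels = Arrays.asList("labelA", "labelB", "labelC");
        // Made-up probabilities standing in for the network output for one image
        double[] output = {0.10, 0.75, 0.15};
        System.out.println(predictLabel(output, trainingLabels));
    }
}
```

Checking the sizes first turns a label mismatch like the one above into an explicit error message instead of an index error deep inside the iterator.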