This is an automated email from the ASF dual-hosted git repository. tsato pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/camel.git
The following commit(s) were added to refs/heads/main by this push: new 4c9c11f9fd2 camel-djl - Add type converter for DJL types 4c9c11f9fd2 is described below commit 4c9c11f9fd21c72f8e4c3463c8b63437533e3f6a Author: Tadayoshi Sato <sato.tadayo...@gmail.com> AuthorDate: Tue Jul 23 14:10:50 2024 +0900 camel-djl - Add type converter for DJL types --- .../camel/component/djl/DJLConverterLoader.java | 72 ++++++++++++++ .../services/org/apache/camel/TypeConverterLoader | 2 + .../camel-djl/src/main/docs/djl-component.adoc | 3 - .../apache/camel/component/djl/DJLConverter.java | 103 +++++++++++++++++++++ .../component/djl/model/AbstractPredictor.java | 2 - .../djl/model/audio/CustomAudioPredictor.java | 50 ++-------- .../djl/model/audio/ZooAudioPredictor.java | 48 ++-------- .../djl/model/cv/AbstractCvZooPredictor.java | 50 ++-------- .../component/djl/model/cv/CustomCvPredictor.java | 45 ++------- .../component/djl/CvActionRecognitionTest.java | 1 - .../djl/CvImageClassificationLocalTest.java | 1 - .../component/djl/CvImageClassificationTest.java | 1 - .../component/djl/CvImageEnhancementLocalTest.java | 1 - .../component/djl/CvInstanceSegmentationTest.java | 1 - .../camel/component/djl/CvObjectDetectionTest.java | 1 - .../camel/component/djl/CvPoseEstimationTest.java | 13 +-- .../component/djl/CvSemanticSegmentationTest.java | 1 - 17 files changed, 203 insertions(+), 192 deletions(-) diff --git a/components/camel-ai/camel-djl/src/generated/java/org/apache/camel/component/djl/DJLConverterLoader.java b/components/camel-ai/camel-djl/src/generated/java/org/apache/camel/component/djl/DJLConverterLoader.java new file mode 100644 index 00000000000..871c79ac82d --- /dev/null +++ b/components/camel-ai/camel-djl/src/generated/java/org/apache/camel/component/djl/DJLConverterLoader.java @@ -0,0 +1,72 @@ +/* Generated by camel build tools - do NOT edit this file! */ +package org.apache.camel.component.djl; + +import javax.annotation.processing.Generated; + +import org.apache.camel.CamelContext; +import org.apache.camel.CamelContextAware; +import org.apache.camel.DeferredContextBinding; +import org.apache.camel.Exchange; +import org.apache.camel.TypeConversionException; +import org.apache.camel.TypeConverterLoaderException; +import org.apache.camel.spi.TypeConverterLoader; +import org.apache.camel.spi.TypeConverterRegistry; +import org.apache.camel.support.SimpleTypeConverter; +import org.apache.camel.support.TypeConverterSupport; +import org.apache.camel.util.DoubleMap; + +/** + * Generated by camel build tools - do NOT edit this file! + */ +@Generated("org.apache.camel.maven.packaging.TypeConverterLoaderGeneratorMojo") +@SuppressWarnings("unchecked") +@DeferredContextBinding +public final class DJLConverterLoader implements TypeConverterLoader, CamelContextAware { + + private CamelContext camelContext; + + public DJLConverterLoader() { + } + + @Override + public void setCamelContext(CamelContext camelContext) { + this.camelContext = camelContext; + } + + @Override + public CamelContext getCamelContext() { + return camelContext; + } + + @Override + public void load(TypeConverterRegistry registry) throws TypeConverterLoaderException { + registerConverters(registry); + } + + private void registerConverters(TypeConverterRegistry registry) { + addTypeConverter(registry, ai.djl.modality.audio.Audio.class, byte[].class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toAudio((byte[]) value)); + addTypeConverter(registry, ai.djl.modality.audio.Audio.class, java.io.File.class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toAudio((java.io.File) value)); + addTypeConverter(registry, ai.djl.modality.audio.Audio.class, java.io.InputStream.class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toAudio((java.io.InputStream) value)); + addTypeConverter(registry, ai.djl.modality.audio.Audio.class, java.nio.file.Path.class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toAudio((java.nio.file.Path) value)); + addTypeConverter(registry, ai.djl.modality.cv.Image.class, ai.djl.modality.cv.output.DetectedObjects.DetectedObject.class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toImage((ai.djl.modality.cv.output.DetectedObjects.DetectedObject) value, exchange)); + addTypeConverter(registry, ai.djl.modality.cv.Image.class, byte[].class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toImage((byte[]) value)); + addTypeConverter(registry, ai.djl.modality.cv.Image.class, java.io.File.class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toImage((java.io.File) value)); + addTypeConverter(registry, ai.djl.modality.cv.Image.class, java.io.InputStream.class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toImage((java.io.InputStream) value)); + addTypeConverter(registry, ai.djl.modality.cv.Image.class, java.nio.file.Path.class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toImage((java.nio.file.Path) value)); + addTypeConverter(registry, ai.djl.modality.cv.Image[].class, ai.djl.modality.cv.output.DetectedObjects.class, false, + (type, exchange, value) -> org.apache.camel.component.djl.DJLConverter.toImages((ai.djl.modality.cv.output.DetectedObjects) value, exchange)); + } + + private static void addTypeConverter(TypeConverterRegistry registry, Class<?> toType, Class<?> fromType, boolean allowNull, SimpleTypeConverter.ConversionMethod method) { + registry.addTypeConverter(toType, fromType, new SimpleTypeConverter(allowNull, method)); + } +} diff --git a/components/camel-ai/camel-djl/src/generated/resources/META-INF/services/org/apache/camel/TypeConverterLoader b/components/camel-ai/camel-djl/src/generated/resources/META-INF/services/org/apache/camel/TypeConverterLoader new file mode 100644 index 00000000000..1b15673a8f8 --- /dev/null +++ b/components/camel-ai/camel-djl/src/generated/resources/META-INF/services/org/apache/camel/TypeConverterLoader @@ -0,0 +1,2 @@ +# Generated by camel build tools - do NOT edit this file! +org.apache.camel.component.djl.DJLConverterLoader diff --git a/components/camel-ai/camel-djl/src/main/docs/djl-component.adoc b/components/camel-ai/camel-djl/src/main/docs/djl-component.adoc index e4a56268ea3..ad51deca03d 100644 --- a/components/camel-ai/camel-djl/src/main/docs/djl-component.adoc +++ b/components/camel-ai/camel-djl/src/main/docs/djl-component.adoc @@ -505,7 +505,6 @@ More information about https://docs.djl.ai/engines/mxnet/index.html[MXNet engine [source,java] ---- from("file:/data/mnist/0/10.png") - .convertBodyTo(byte[].class) .to("djl:cv/image_classification?artifactId=ai.djl.mxnet:mlp:0.0.1"); ---- @@ -513,7 +512,6 @@ from("file:/data/mnist/0/10.png") [source,java] ---- from("file:/data/mnist/0/10.png") - .convertBodyTo(byte[].class) .to("djl:cv/image_classification?artifactId=ai.djl.mxnet:mlp:0.0.1"); ---- @@ -537,7 +535,6 @@ context.getRegistry().bind("MyModel", model); context.getRegistry().bind("MyTranslator", translator); from("file:/data/mnist/0/10.png") - .convertBodyTo(byte[].class) .to("djl:cv/image_classification?model=MyModel&translator=MyTranslator"); ---- diff --git a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/DJLConverter.java b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/DJLConverter.java new file mode 100644 index 00000000000..1533d6c2418 --- /dev/null +++ b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/DJLConverter.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.camel.component.djl; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; + +import ai.djl.modality.audio.Audio; +import ai.djl.modality.audio.AudioFactory; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import org.apache.camel.Converter; +import org.apache.camel.Exchange; + +/** + * Converter methods to convert from / to DJL types. + */ +@Converter(generateLoader = true) +public class DJLConverter { + + @Converter + public static Image toImage(byte[] bytes) throws IOException { + return toImage(new ByteArrayInputStream(bytes)); + } + + @Converter + public static Image toImage(File file) throws IOException { + return toImage(new FileInputStream(file)); + } + + @Converter + public static Image toImage(Path path) throws IOException { + return toImage(Files.newInputStream(path)); + } + + @Converter + public static Image toImage(InputStream inputStream) throws IOException { + return ImageFactory.getInstance().fromInputStream(inputStream); + } + + @Converter + public static Image toImage(DetectedObjects.DetectedObject detectedObject, Exchange exchange) { + if (exchange == null || exchange.getMessage() == null) { + return null; + } + + Rectangle rect = detectedObject.getBoundingBox().getBounds(); + Image image = exchange.getMessage().getHeader(DJLConstants.INPUT, Image.class); + return image.getSubImage( + (int) (rect.getX() * image.getWidth()), + (int) (rect.getY() * image.getHeight()), + (int) (rect.getWidth() * image.getWidth()), + (int) (rect.getHeight() * image.getHeight())); + } + + @Converter + public static Image[] toImages(DetectedObjects detectedObjects, Exchange exchange) { + return detectedObjects.<DetectedObjects.DetectedObject> items().stream() + .map(obj -> toImage(obj, exchange)) + .toArray(Image[]::new); + } + + @Converter + public static Audio toAudio(byte[] bytes) throws IOException { + return toAudio(new ByteArrayInputStream(bytes)); + } + + @Converter + public static Audio toAudio(File file) throws IOException { + return toAudio(new FileInputStream(file)); + } + + @Converter + public static Audio toAudio(Path path) throws IOException { + return toAudio(Files.newInputStream(path)); + } + + @Converter + public static Audio toAudio(InputStream inputStream) throws IOException { + return AudioFactory.newInstance().fromInputStream(inputStream); + } +} diff --git a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/AbstractPredictor.java b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/AbstractPredictor.java index ab3c489feb2..9cc43f34a67 100644 --- a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/AbstractPredictor.java +++ b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/AbstractPredictor.java @@ -21,8 +21,6 @@ import org.apache.camel.component.djl.DJLEndpoint; public abstract class AbstractPredictor { - protected static final String FAILED_TO_TRANSFORM_MESSAGE = "Couldn't transform input into a BufferedImage"; - private final DJLEndpoint endpoint; public AbstractPredictor(DJLEndpoint endpoint) { diff --git a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/CustomAudioPredictor.java b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/CustomAudioPredictor.java index 3b2bb48dd1f..5a10e928d53 100644 --- a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/CustomAudioPredictor.java +++ b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/CustomAudioPredictor.java @@ -16,30 +16,20 @@ */ package org.apache.camel.component.djl.model.audio; -import java.io.ByteArrayInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; - import ai.djl.Model; import ai.djl.inference.Predictor; import ai.djl.modality.audio.Audio; -import ai.djl.modality.audio.AudioFactory; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import org.apache.camel.Exchange; import org.apache.camel.RuntimeCamelException; +import org.apache.camel.TypeConversionException; import org.apache.camel.component.djl.DJLConstants; import org.apache.camel.component.djl.DJLEndpoint; import org.apache.camel.component.djl.model.AbstractPredictor; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class CustomAudioPredictor extends AbstractPredictor { - private static final Logger LOG = LoggerFactory.getLogger(CustomAudioPredictor.class); - protected final String modelName; protected final String translatorName; @@ -51,42 +41,14 @@ public class CustomAudioPredictor extends AbstractPredictor { @Override public void process(Exchange exchange) throws Exception { - Object body = exchange.getIn().getBody(); - String result; - if (body instanceof Audio) { - result = predict(exchange, exchange.getIn().getBody(Audio.class)); - } else if (body instanceof byte[]) { - byte[] bytes = exchange.getIn().getBody(byte[].class); - result = predict(exchange, new ByteArrayInputStream(bytes)); - } else if (body instanceof File) { - result = predict(exchange, exchange.getIn().getBody(File.class)); - } else if (body instanceof InputStream) { - result = predict(exchange, exchange.getIn().getBody(InputStream.class)); - } else { + try { + Audio audio = exchange.getIn().getBody(Audio.class); + String result = predict(exchange, audio); + exchange.getIn().setBody(result); + } catch (TypeConversionException e) { throw new RuntimeCamelException( "Data type is not supported. Body should be ai.djl.modality.audio.Audio, byte[], InputStream or File"); } - exchange.getIn().setBody(result); - } - - protected String predict(Exchange exchange, File input) { - try (InputStream fileInputStream = new FileInputStream(input)) { - Audio audio = AudioFactory.newInstance().fromInputStream(fileInputStream); - return predict(exchange, audio); - } catch (IOException e) { - LOG.error(FAILED_TO_TRANSFORM_MESSAGE); - throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e); - } - } - - protected String predict(Exchange exchange, InputStream input) { - try { - Audio audio = AudioFactory.newInstance().fromInputStream(input); - return predict(exchange, audio); - } catch (IOException e) { - LOG.error(FAILED_TO_TRANSFORM_MESSAGE); - throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e); - } } protected String predict(Exchange exchange, Audio audio) { diff --git a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioPredictor.java b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioPredictor.java index e8750976383..9d369c92208 100644 --- a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioPredictor.java +++ b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/audio/ZooAudioPredictor.java @@ -16,17 +16,12 @@ */ package org.apache.camel.component.djl.model.audio; -import java.io.ByteArrayInputStream; -import java.io.File; -import java.io.FileInputStream; import java.io.IOException; -import java.io.InputStream; import ai.djl.Application; import ai.djl.MalformedModelException; import ai.djl.inference.Predictor; import ai.djl.modality.audio.Audio; -import ai.djl.modality.audio.AudioFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; @@ -35,16 +30,13 @@ import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; import org.apache.camel.Exchange; import org.apache.camel.RuntimeCamelException; +import org.apache.camel.TypeConversionException; import org.apache.camel.component.djl.DJLConstants; import org.apache.camel.component.djl.DJLEndpoint; import org.apache.camel.component.djl.model.AbstractPredictor; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class ZooAudioPredictor extends AbstractPredictor { - private static final Logger LOG = LoggerFactory.getLogger(ZooAudioPredictor.class); - private final ZooModel<Audio, String> model; public ZooAudioPredictor(DJLEndpoint endpoint) throws ModelNotFoundException, MalformedModelException, IOException { @@ -64,42 +56,14 @@ public class ZooAudioPredictor extends AbstractPredictor { @Override public void process(Exchange exchange) throws Exception { - Object body = exchange.getIn().getBody(); - String result; - if (body instanceof Audio) { - result = predict(exchange, exchange.getIn().getBody(Audio.class)); - } else if (body instanceof byte[]) { - byte[] bytes = exchange.getIn().getBody(byte[].class); - result = predict(exchange, new ByteArrayInputStream(bytes)); - } else if (body instanceof File) { - result = predict(exchange, exchange.getIn().getBody(File.class)); - } else if (body instanceof InputStream) { - result = predict(exchange, exchange.getIn().getBody(InputStream.class)); - } else { + try { + Audio audio = exchange.getIn().getBody(Audio.class); + String result = predict(exchange, audio); + exchange.getIn().setBody(result); + } catch (TypeConversionException e) { throw new RuntimeCamelException( "Data type is not supported. Body should be ai.djl.modality.audio.Audio, byte[], InputStream or File"); } - exchange.getIn().setBody(result); - } - - protected String predict(Exchange exchange, File input) { - try (InputStream fileInputStream = new FileInputStream(input)) { - Audio audio = AudioFactory.newInstance().fromInputStream(fileInputStream); - return predict(exchange, audio); - } catch (IOException e) { - LOG.error(FAILED_TO_TRANSFORM_MESSAGE); - throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e); - } - } - - protected String predict(Exchange exchange, InputStream input) { - try { - Audio audio = AudioFactory.newInstance().fromInputStream(input); - return predict(exchange, audio); - } catch (IOException e) { - LOG.error(FAILED_TO_TRANSFORM_MESSAGE); - throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e); - } } protected String predict(Exchange exchange, Audio audio) { diff --git a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/AbstractCvZooPredictor.java b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/AbstractCvZooPredictor.java index 737183d7692..cfdef670843 100644 --- a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/AbstractCvZooPredictor.java +++ b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/AbstractCvZooPredictor.java @@ -16,29 +16,19 @@ */ package org.apache.camel.component.djl.model.cv; -import java.io.ByteArrayInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; - import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; -import ai.djl.modality.cv.ImageFactory; import ai.djl.repository.zoo.ZooModel; import ai.djl.translate.TranslateException; import org.apache.camel.Exchange; import org.apache.camel.RuntimeCamelException; +import org.apache.camel.TypeConversionException; import org.apache.camel.component.djl.DJLConstants; import org.apache.camel.component.djl.DJLEndpoint; import org.apache.camel.component.djl.model.AbstractPredictor; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public abstract class AbstractCvZooPredictor<T> extends AbstractPredictor { - private static final Logger LOG = LoggerFactory.getLogger(AbstractCvZooPredictor.class); - protected ZooModel<Image, T> model; public AbstractCvZooPredictor(DJLEndpoint endpoint) { @@ -47,42 +37,14 @@ public abstract class AbstractCvZooPredictor<T> extends AbstractPredictor { @Override public void process(Exchange exchange) { - Object body = exchange.getIn().getBody(); - T result; - if (body instanceof Image) { - result = predict(exchange, exchange.getIn().getBody(Image.class)); - } else if (body instanceof byte[]) { - byte[] bytes = exchange.getIn().getBody(byte[].class); - result = predict(exchange, new ByteArrayInputStream(bytes)); - } else if (body instanceof File) { - result = predict(exchange, exchange.getIn().getBody(File.class)); - } else if (body instanceof InputStream) { - result = predict(exchange, exchange.getIn().getBody(InputStream.class)); - } else { + try { + Image image = exchange.getIn().getBody(Image.class); + T result = predict(exchange, image); + exchange.getIn().setBody(result); + } catch (TypeConversionException e) { throw new RuntimeCamelException( "Data type is not supported. Body should be ai.djl.modality.cv.Image, byte[], InputStream or File"); } - exchange.getIn().setBody(result); - } - - protected T predict(Exchange exchange, File input) { - try (InputStream fileInputStream = new FileInputStream(input)) { - Image image = ImageFactory.getInstance().fromInputStream(fileInputStream); - return predict(exchange, image); - } catch (IOException e) { - LOG.error(FAILED_TO_TRANSFORM_MESSAGE); - throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e); - } - } - - protected T predict(Exchange exchange, InputStream input) { - try { - Image image = ImageFactory.getInstance().fromInputStream(input); - return predict(exchange, image); - } catch (IOException e) { - LOG.error(FAILED_TO_TRANSFORM_MESSAGE); - throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e); - } } protected T predict(Exchange exchange, Image image) { diff --git a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomCvPredictor.java b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomCvPredictor.java index 3f4905798af..ccf327e3cfe 100644 --- a/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomCvPredictor.java +++ b/components/camel-ai/camel-djl/src/main/java/org/apache/camel/component/djl/model/cv/CustomCvPredictor.java @@ -16,20 +16,14 @@ */ package org.apache.camel.component.djl.model.cv; -import java.io.ByteArrayInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; - import ai.djl.Model; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; -import ai.djl.modality.cv.ImageFactory; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import org.apache.camel.Exchange; import org.apache.camel.RuntimeCamelException; +import org.apache.camel.TypeConversionException; import org.apache.camel.component.djl.DJLConstants; import org.apache.camel.component.djl.DJLEndpoint; import org.apache.camel.component.djl.model.AbstractPredictor; @@ -51,38 +45,13 @@ public class CustomCvPredictor<T> extends AbstractPredictor { @Override public void process(Exchange exchange) throws Exception { - Object body = exchange.getIn().getBody(); - T result; - if (body instanceof byte[]) { - byte[] bytes = exchange.getIn().getBody(byte[].class); - result = predict(exchange, new ByteArrayInputStream(bytes)); - } else if (body instanceof File) { - result = predict(exchange, exchange.getIn().getBody(File.class)); - } else if (body instanceof InputStream) { - result = predict(exchange, exchange.getIn().getBody(InputStream.class)); - } else { - throw new RuntimeCamelException("Data type is not supported. Body should be byte[], InputStream or File"); - } - exchange.getIn().setBody(result); - } - - protected T predict(Exchange exchange, File input) { - try (InputStream fileInputStream = new FileInputStream(input)) { - Image image = ImageFactory.getInstance().fromInputStream(fileInputStream); - return predict(exchange, image); - } catch (IOException e) { - LOG.error(FAILED_TO_TRANSFORM_MESSAGE); - throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e); - } - } - - protected T predict(Exchange exchange, InputStream input) { try { - Image image = ImageFactory.getInstance().fromInputStream(input); - return predict(exchange, image); - } catch (IOException e) { - LOG.error(FAILED_TO_TRANSFORM_MESSAGE); - throw new RuntimeCamelException(FAILED_TO_TRANSFORM_MESSAGE, e); + Image image = exchange.getIn().getBody(Image.class); + T result = predict(exchange, image); + exchange.getIn().setBody(result); + } catch (TypeConversionException e) { + throw new RuntimeCamelException( + "Data type is not supported. Body should be ai.djl.modality.cv.Image, byte[], InputStream or File"); } } diff --git a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvActionRecognitionTest.java b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvActionRecognitionTest.java index 377f023b28a..9f6bef79ae6 100644 --- a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvActionRecognitionTest.java +++ b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvActionRecognitionTest.java @@ -41,7 +41,6 @@ public class CvActionRecognitionTest extends CamelTestSupport { return new RouteBuilder() { public void configure() { from("file:src/test/resources/data/action?recursive=true&noop=true") - .convertBodyTo(byte[].class) .to("djl:cv/action_recognition?artifactId=ai.djl.mxnet:action_recognition:0.0.1") .log("${header.CamelFileName} = ${body}") .to("mock:result"); diff --git a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageClassificationLocalTest.java b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageClassificationLocalTest.java index 8c3c02180fc..26332ebd014 100644 --- a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageClassificationLocalTest.java +++ b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageClassificationLocalTest.java @@ -73,7 +73,6 @@ public class CvImageClassificationLocalTest extends CamelTestSupport { public void configure() { from("file:src/test/resources/data/mnist?recursive=true&noop=true") .routeId("infer").autoStartup(false) - .convertBodyTo(byte[].class) .to("djl:cv/image_classification?model=MyModel&translator=MyTranslator") .log("${header.CamelFileName} = ${body}") .process(exchange -> { diff --git a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageClassificationTest.java b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageClassificationTest.java index 5d04bfc6302..74d679ff5a6 100644 --- a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageClassificationTest.java +++ b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageClassificationTest.java @@ -42,7 +42,6 @@ public class CvImageClassificationTest extends CamelTestSupport { return new RouteBuilder() { public void configure() { from("file:src/test/resources/data/mnist?recursive=true&noop=true") - .convertBodyTo(byte[].class) .to("djl:cv/image_classification?artifactId=ai.djl.zoo:mlp:0.0.3&showProgress=true") .log("${header.CamelFileName} = ${body.best.className}") .to("mock:result"); diff --git a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageEnhancementLocalTest.java b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageEnhancementLocalTest.java index ed04d1183bc..1292aa3663d 100644 --- a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageEnhancementLocalTest.java +++ b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvImageEnhancementLocalTest.java @@ -71,7 +71,6 @@ public class CvImageEnhancementLocalTest extends CamelTestSupport { public void configure() { from("file:src/test/resources/data/enhance?recursive=true&noop=true") .routeId("image_enhancement").autoStartup(false) - .convertBodyTo(byte[].class) .to("djl:cv/image_enhancement?model=MyModel&translator=MyTranslator") .log("${header.CamelFileName} = ${body}") .process(exchange -> { diff --git a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvInstanceSegmentationTest.java b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvInstanceSegmentationTest.java index 5b57616cc91..19a1333420c 100644 --- a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvInstanceSegmentationTest.java +++ b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvInstanceSegmentationTest.java @@ -41,7 +41,6 @@ public class CvInstanceSegmentationTest extends CamelTestSupport { return new RouteBuilder() { public void configure() { from("file:src/test/resources/data/detect?recursive=true&noop=true") - .convertBodyTo(byte[].class) .to("djl:cv/instance_segmentation?artifactId=ai.djl.mxnet:mask_rcnn:0.0.1") .log("${header.CamelFileName} = ${body}") .to("mock:result"); diff --git a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvObjectDetectionTest.java b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvObjectDetectionTest.java index 6a84201fff4..7c43b852bf7 100644 --- a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvObjectDetectionTest.java +++ b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvObjectDetectionTest.java @@ -42,7 +42,6 @@ public class CvObjectDetectionTest extends CamelTestSupport { return new RouteBuilder() { public void configure() { from("file:src/test/resources/data/detect?recursive=true&noop=true") - .convertBodyTo(byte[].class) .to("djl:cv/object_detection?artifactId=ai.djl.pytorch:ssd:0.0.1") .log("${header.CamelFileName} = ${body}") .to("mock:result"); diff --git a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvPoseEstimationTest.java b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvPoseEstimationTest.java index 5d70f6acae2..b1221bc760b 100644 --- a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvPoseEstimationTest.java +++ b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvPoseEstimationTest.java @@ -17,7 +17,6 @@ package org.apache.camel.component.djl; import ai.djl.modality.cv.Image; -import ai.djl.modality.cv.output.DetectedObjects; import org.apache.camel.builder.RouteBuilder; import org.apache.camel.test.junit5.CamelTestSupport; import org.junit.jupiter.api.BeforeAll; @@ -43,21 +42,11 @@ public class CvPoseEstimationTest extends CamelTestSupport { return new RouteBuilder() { public void configure() { from("file:src/test/resources/data/pose?recursive=true&noop=true") - .convertBodyTo(byte[].class) .to("djl:cv/object_detection?artifactId=ai.djl.mxnet:ssd:0.0.1") .log("${header.CamelFileName} = ${body}") .split(simple("${body.items}")) .filter(simple("${body.className} == 'person'")) - .process(exchange -> { - var obj = exchange.getMessage().getBody(DetectedObjects.DetectedObject.class); - var rect = obj.getBoundingBox().getBounds(); - var image = exchange.getIn().getHeader(DJLConstants.INPUT, Image.class); - exchange.getIn().setBody(image.getSubImage( - (int) (rect.getX() * image.getWidth()), - (int) (rect.getY() * image.getHeight()), - (int) (rect.getWidth() * image.getWidth()), - (int) (rect.getHeight() * image.getHeight()))); - }) + .convertBodyTo(Image.class) .to("djl:cv/pose_estimation?artifactId=ai.djl.mxnet:simple_pose:0.0.1") .log("${header.CamelFileName} = ${body}") .to("mock:result"); diff --git a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvSemanticSegmentationTest.java b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvSemanticSegmentationTest.java index 2035d690320..1489331e4f2 100644 --- a/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvSemanticSegmentationTest.java +++ b/components/camel-ai/camel-djl/src/test/java/org/apache/camel/component/djl/CvSemanticSegmentationTest.java @@ -41,7 +41,6 @@ public class CvSemanticSegmentationTest extends CamelTestSupport { return new RouteBuilder() { public void configure() { from("file:src/test/resources/data/detect?recursive=true&noop=true") - .convertBodyTo(byte[].class) .to("djl:cv/semantic_segmentation?artifactId=ai.djl.pytorch:deeplabv3:0.0.1") .log("${header.CamelFileName} = ${body}") .to("mock:result");