Ma77Ball commented on code in PR #5124:
URL: https://github.com/apache/texera/pull/5124#discussion_r3310294213


##########
amber/src/main/scala/org/apache/texera/web/resource/HuggingFaceModelResource.scala:
##########
@@ -0,0 +1,504 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.web.resource
+
+import com.fasterxml.jackson.core.`type`.TypeReference
+import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
+import kong.unirest.Unirest
+
+import javax.ws.rs._
+import javax.ws.rs.core.{MediaType, Response}
+import java.nio.file.{Files, Path => NioPath, Paths}
+import java.util.concurrent.ConcurrentHashMap
+import java.util.stream.Collectors
+import scala.jdk.CollectionConverters._
+
+/**
+  * REST resource that proxies the Hugging Face Hub API to list
+  * models for the HuggingFace operator.
+  *
+  * Browse mode:  GET /api/huggingface/models?task=text-generation
+  *   Fetches ALL models for the task from HF Hub (paginated internally),
+  *   caches the full list server-side, and returns it.
+  *
+  * Search mode:  GET /api/huggingface/models?task=text-generation&search=bert
+  *   Forwards the search query to HF Hub API (searches all models).
+  */
+@Path("/huggingface")
+@Produces(Array(MediaType.APPLICATION_JSON))
+class HuggingFaceModelResource {
+
+  import HuggingFaceModelResource._
+
+  /**
+    * Lists Hugging Face Hub models for a pipeline task.
+    *
+    * @param task   pipeline tag to filter on; defaults to `text-generation`
+    * @param search optional free-text query; when non-blank it is forwarded to the Hub and the
+    *               result is returned without touching the model cache
+    * @return 200 with a JSON array of models; in search mode, the Hub's status with a JSON error
+    *         body when the Hub does not answer 200; 500 with a JSON `{"error": ...}` body if
+    *         fetching or serialization throws
+    */
+  @GET
+  @Path("/models")
+  def listModels(
+      @QueryParam("task") @DefaultValue("text-generation") task: String,
+      @QueryParam("search") search: String
+  ): Response = {
+    import scala.util.control.NonFatal
+
+    try {
+      val hfToken = Option(System.getenv("HF_TOKEN")).getOrElse("")
+
+      Option(search).map(_.trim).filter(_.nonEmpty) match {
+        // ── Search mode: forward query to HF Hub, return results directly ──
+        case Some(query) =>
+          fetchSearchResults(task, query, hfToken)
+
+        // ── Browse mode: return ALL models for this task (cached) ──
+        case None =>
+          val json = Option(modelCache.get(task)).getOrElse {
+            // Not cached — fetch all pages from the HF Hub API. There is no lock between the
+            // get and the put, so concurrent misses for the same task may each fetch and
+            // overwrite the entry.
+            val fresh = objectMapper.writeValueAsString(fetchAllModelsForTask(task, hfToken))
+            modelCache.put(task, fresh)
+            fresh
+          }
+          Response.ok(json).build()
+      }
+    } catch {
+      case NonFatal(e) =>
+        // Serialize with Jackson so quotes, backslashes or newlines in the exception message
+        // cannot produce malformed JSON.
+        Response
+          .status(Response.Status.INTERNAL_SERVER_ERROR)
+          .entity(
+            objectMapper.writeValueAsString(
+              Map("error" -> s"Failed to fetch models: ${e.getMessage}").asJava
+            )
+          )
+          .build()
+    }
+  }
+
+  /**
+    * Writes the request body to a new file under `<java.io.tmpdir>/texera-hf-audio`.
+    *
+    * @param filename client-supplied name; only its last path component is kept, and its
+    *                 extension is used as the temp-file suffix when it is short and alphanumeric
+    * @param bytes    raw audio payload
+    * @return 200 with `{"path": <absolute temp path>, "fileName": <sanitized name>}`, 400 for an
+    *         empty payload, or 500 with a JSON error body if writing fails
+    */
+  @POST
+  @Path("/upload-audio")
+  @Consumes(Array(MediaType.WILDCARD))
+  def uploadAudioReference(
+      @QueryParam("filename") filename: String,
+      bytes: Array[Byte]
+  ): Response = {
+    import scala.util.Try
+    import scala.util.control.NonFatal
+
+    def errorResponse(status: Response.Status, message: String): Response =
+      Response
+        .status(status)
+        .entity(objectMapper.writeValueAsString(Map("error" -> message).asJava))
+        .build()
+
+    try {
+      if (bytes == null || bytes.isEmpty) {
+        errorResponse(Response.Status.BAD_REQUEST, "Audio payload is empty.")
+      } else {
+        // Keep only the last path component. Paths.get throws on invalid names (e.g. NUL) and
+        // getFileName is null for root-like names such as "/"; both fall back to the default.
+        val safeFileName = Option(filename)
+          .map(_.trim)
+          .filter(_.nonEmpty)
+          .flatMap(name => Try(Paths.get(name)).toOption)
+          .flatMap(parsed => Option(parsed.getFileName))
+          .map(_.toString)
+          .filter(_.nonEmpty)
+          .getOrElse("audio.bin")
+
+        // The suffix comes from untrusted input, so accept only a short alphanumeric extension.
+        val extension = {
+          val idx = safeFileName.lastIndexOf('.')
+          val candidate = if (idx >= 0) safeFileName.substring(idx) else ""
+          if (candidate.matches("""\.[A-Za-z0-9]{1,16}""")) candidate else ".bin"
+        }
+
+        val tempDir = Paths.get(System.getProperty("java.io.tmpdir"), "texera-hf-audio")
+        Files.createDirectories(tempDir)
+        val tempFile: NioPath = Files.createTempFile(tempDir, "hf-audio-", extension)
+        Files.write(tempFile, bytes)
+
+        val json = objectMapper.writeValueAsString(
+          Map(
+            "path" -> tempFile.toAbsolutePath.toString,
+            "fileName" -> safeFileName
+          ).asJava
+        )
+        Response.ok(json).build()
+      }
+    } catch {
+      case NonFatal(e) =>
+        errorResponse(
+          Response.Status.INTERNAL_SERVER_ERROR,
+          s"Failed to upload audio: ${e.getMessage}"
+        )
+    }
+  }
+
+  /**
+    * Returns the bytes of a file under `<java.io.tmpdir>/texera-hf-audio`, such as one written
+    * by the `/upload-audio` endpoint.
+    *
+    * @param path path of the file to return
+    * @return 200 with the file bytes and a probed (or inferred) content type; 400 if the path is
+    *         missing or malformed; 403 if it, or its resolved symlink target, lies outside the
+    *         audio temp directory; 404 if it is not a regular file; 500 with a JSON error body
+    *         on any other failure
+    */
+  @GET
+  @Path("/audio-preview")
+  def previewUploadedAudio(@QueryParam("path") path: String): Response = {
+    import scala.util.Try
+    import scala.util.control.NonFatal
+
+    def errorResponse(status: Response.Status, message: String): Response =
+      Response
+        .status(status)
+        .entity(objectMapper.writeValueAsString(Map("error" -> message).asJava))
+        .build()
+
+    val outsideDirMessage = "Audio path is outside the allowed preview directory."
+
+    try {
+      val trimmedPath = Option(path).map(_.trim).getOrElse("")
+      val tempDir = Paths
+        .get(System.getProperty("java.io.tmpdir"), "texera-hf-audio")
+        .toAbsolutePath
+        .normalize()
+
+      if (trimmedPath.isEmpty) {
+        errorResponse(Response.Status.BAD_REQUEST, "Audio path is required.")
+      } else {
+        Try(Paths.get(trimmedPath).toAbsolutePath.normalize()).toOption match {
+          case None =>
+            errorResponse(Response.Status.BAD_REQUEST, "Audio path is invalid.")
+          // Lexical check first: rejects "../" traversal before touching the file system.
+          case Some(requestedPath) if !requestedPath.startsWith(tempDir) =>
+            errorResponse(Response.Status.FORBIDDEN, outsideDirMessage)
+          case Some(requestedPath) if !Files.isRegularFile(requestedPath) =>
+            errorResponse(Response.Status.NOT_FOUND, "Uploaded audio file was not found.")
+          case Some(requestedPath) =>
+            // normalize() does not resolve symlinks, so re-check the real location to stop a
+            // link placed inside the temp directory from exposing files elsewhere.
+            val realPath = requestedPath.toRealPath()
+            if (!realPath.startsWith(tempDir.toRealPath())) {
+              errorResponse(Response.Status.FORBIDDEN, outsideDirMessage)
+            } else {
+              val contentType = Option(Files.probeContentType(requestedPath))
+                .filter(_.trim.nonEmpty)
+                .getOrElse(inferAudioContentType(requestedPath))
+              Response.ok(Files.readAllBytes(realPath), contentType).build()
+            }
+        }
+      }
+    } catch {
+      case NonFatal(e) =>
+        errorResponse(
+          Response.Status.INTERNAL_SERVER_ERROR,
+          s"Failed to read uploaded audio: ${e.getMessage}"
+        )
+    }
+  }
+
+  /**
+    * Fetches a remote http(s) resource server-side and relays its body and content type.
+    *
+    * Because the server performs the request, the target host is resolved first. Requests to
+    * loopback, wildcard, link-local (e.g. cloud metadata at 169.254.169.254), private-network,
+    * IPv6 unique-local or multicast addresses are refused, to limit server-side request forgery.
+    * NOTE(review): the HTTP client resolves the host again and may follow redirects, so this
+    * check does not cover DNS rebinding or redirect targets; a host allowlist would be stronger.
+    *
+    * @param url absolute `http://` or `https://` URL to fetch
+    * @return 200 with the upstream body; 400 for a missing, malformed or non-http(s) URL; 403
+    *         for an internal target; the upstream status with a JSON error body when it is not
+    *         200; 500 with a JSON error body if the request throws
+    */
+  @GET
+  @Path("/media-proxy")
+  def proxyRemoteMedia(@QueryParam("url") url: String): Response = {
+    import java.net.{InetAddress, URI}
+    import scala.util.Try
+    import scala.util.control.NonFatal
+
+    def errorResponse(status: Int, message: String): Response =
+      Response
+        .status(status)
+        .entity(objectMapper.writeValueAsString(Map("error" -> message).asJava))
+        .build()
+
+    def isInternalAddress(address: InetAddress): Boolean = {
+      val raw = address.getAddress
+      // fc00::/7 (IPv6 unique-local) is not covered by isSiteLocalAddress.
+      val isUniqueLocalV6 = raw.length == 16 && (raw(0) & 0xfe) == 0xfc
+      address.isLoopbackAddress || address.isAnyLocalAddress || address.isLinkLocalAddress ||
+      address.isSiteLocalAddress || address.isMulticastAddress || isUniqueLocalV6
+    }
+
+    val badRequest = Response.Status.BAD_REQUEST.getStatusCode
+
+    try {
+      val trimmedUrl = Option(url).map(_.trim).getOrElse("")
+      if (trimmedUrl.isEmpty) {
+        errorResponse(badRequest, "Media URL is required.")
+      } else if (!trimmedUrl.startsWith("http://") && !trimmedUrl.startsWith("https://")) {
+        errorResponse(badRequest, "Only http(s) media URLs are supported.")
+      } else {
+        val host = Try(new URI(trimmedUrl)).toOption
+          .flatMap(uri => Option(uri.getHost))
+          .filter(_.nonEmpty)
+        host match {
+          case None =>
+            errorResponse(badRequest, "Media URL is invalid.")
+          // An unresolvable host throws UnknownHostException, which is reported as a 500 below.
+          case Some(h) if InetAddress.getAllByName(h).exists(isInternalAddress) =>
+            errorResponse(
+              Response.Status.FORBIDDEN.getStatusCode,
+              "Media URL points to a disallowed network address."
+            )
+          case Some(_) =>
+            val upstreamResponse = Unirest
+              .get(trimmedUrl)
+              .connectTimeout(10000)
+              .socketTimeout(120000)
+              .asBytes()
+
+            if (upstreamResponse.getStatus != 200) {
+              errorResponse(
+                upstreamResponse.getStatus,
+                s"Failed to fetch remote media: ${upstreamResponse.getStatusText}"
+              )
+            } else {
+              val contentType = Option(upstreamResponse.getHeaders.getFirst("Content-Type"))
+                .filter(_.trim.nonEmpty)
+                .getOrElse(MediaType.APPLICATION_OCTET_STREAM)
+              Response.ok(upstreamResponse.getBody, contentType).build()
+            }
+        }
+      }
+    } catch {
+      case NonFatal(e) =>
+        errorResponse(
+          Response.Status.INTERNAL_SERVER_ERROR.getStatusCode,
+          s"Failed to proxy remote media: ${e.getMessage}"
+        )
+    }
+  }
+
+  /**
+    * Searches HF Hub for models matching a query within a task.
+    *
+    * Queries `/api/models` with `pipeline_tag` and `filter` set to the task, `inference=warm`,
+    * sorted by downloads descending and limited to 100 results, then simplifies the response
+    * with `buildSimplifiedList`.
+    *
+    * @param task    pipeline tag to restrict the search to
+    * @param query   non-blank search text
+    * @param hfToken Hub access token; sent as a Bearer header only when non-empty
+    * @return 200 with a JSON array of simplified models, or the Hub's status with a JSON error
+    *         body when the Hub does not answer 200
+    */
+  private def fetchSearchResults(task: String, query: String, hfToken: String): Response = {
+    val baseRequest = Unirest
+      .get("https://huggingface.co/api/models")
+      // Explicit bounds, matching listTasks, so a slow Hub response fails in bounded time.
+      .connectTimeout(10000)
+      .socketTimeout(15000)
+      .queryString("pipeline_tag", task)
+      .queryString("sort", "downloads")
+      .queryString("direction", "-1")
+      .queryString("limit", "100")
+      .queryString("filter", task)
+      .queryString("inference", "warm")
+      .queryString("search", query)
+    val request =
+      if (hfToken.nonEmpty) baseRequest.header("Authorization", s"Bearer $hfToken")
+      else baseRequest
+
+    val hfResponse = request.asString()
+
+    if (hfResponse.getStatus != 200) {
+      // Serialize with Jackson: the upstream status text must not be able to break the JSON.
+      Response
+        .status(hfResponse.getStatus)
+        .entity(
+          objectMapper.writeValueAsString(
+            Map("error" -> s"Hugging Face API error: ${hfResponse.getStatusText}").asJava
+          )
+        )
+        .build()
+    } else {
+      val rawModels = objectMapper.readValue(hfResponse.getBody, listOfMapsType)
+      Response.ok(objectMapper.writeValueAsString(buildSimplifiedList(rawModels))).build()
+    }
+  }
+
+  /**
+    * Fetch pipeline task tags from the Hugging Face Hub API.
+    * GET /api/huggingface/tasks
+    *
+    * Returns a JSON array of objects: [{ "tag": "text-generation", "label": 
"Text Generation" }, ...]
+    * The result is cached server-side for the lifetime of the process.
+    */
+  @GET
+  @Path("/tasks")
+  def listTasks(): Response = {
+    try {
+      val cached = taskCache.get("all")
+      if (cached != null) {
+        return Response.ok(cached).build()
+      }
+
+      val hfToken = Option(System.getenv("HF_TOKEN")).getOrElse("")
+      var request = Unirest
+        .get("https://huggingface.co/api/tasks";)
+        .connectTimeout(10000)
+        .socketTimeout(15000)
+
+      if (hfToken.nonEmpty) {
+        request = request.header("Authorization", s"Bearer $hfToken")
+      }
+
+      val hfResponse = request.asString()
+
+      if (hfResponse.getStatus != 200) {
+        return Response
+          .status(hfResponse.getStatus)
+          .entity(s"""{"error":"Hugging Face API error: 
${hfResponse.getStatusText}"}""")
+          .build()
+      }
+
+      // /api/tasks returns a JSON object: { "<pipeline_tag>": { "label": 
"...", ... }, ... }
+      // Using readTree so no entry is dropped regardless of its value type 
(null, array, etc.)
+      val root: JsonNode = objectMapper.readTree(hfResponse.getBody)
+      val taskList = new java.util.ArrayList[java.util.Map[String, Object]]()
+      val iter = root.fields()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val tag = entry.getKey
+        val info: JsonNode = entry.getValue
+        val label =
+          if (info != null && info.isObject && info.has("label")) 
info.get("label").asText(tag)
+          else tag
+        val taskEntry = new java.util.LinkedHashMap[String, Object]()
+        taskEntry.put("tag", tag)
+        taskEntry.put("label", label)
+        taskList.add(taskEntry)
+      }
+
+      // Filter out tasks that have no models available with hosted inference
+      val availableTasks = taskList
+        .parallelStream()
+        .filter(task => hasModelsForTask(task.get("tag").toString, hfToken))
+        .collect(Collectors.toList())

Review Comment:
   A follow-up will be fine 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to