Skip to content

Commit

Permalink
[Vertex AI] Add ImagenGenerationConfig to generateImages() (#14234)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Dec 9, 2024
1 parent 36a76e2 commit c5472fc
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenAspectRatio {
public static let square1x1 = ImagenAspectRatio(kind: .square1x1)

public static let portrait9x16 = ImagenAspectRatio(kind: .portrait9x16)

public static let landscape16x9 = ImagenAspectRatio(kind: .landscape16x9)

public static let portrait3x4 = ImagenAspectRatio(kind: .portrait3x4)

public static let landscape4x3 = ImagenAspectRatio(kind: .landscape4x3)

let rawValue: String
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenAspectRatio: ProtoEnum {
enum Kind: String {
case square1x1 = "1:1"
case portrait9x16 = "9:16"
case landscape16x9 = "16:9"
case portrait3x4 = "3:4"
case landscape4x3 = "4:3"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenGenerationConfig {
public var numberOfImages: Int?
public var negativePrompt: String?
public var aspectRatio: ImagenAspectRatio?
public var imageFormat: ImagenImageFormat?
public var addWatermark: Bool?

public init(numberOfImages: Int? = nil,
negativePrompt: String? = nil,
aspectRatio: ImagenAspectRatio? = nil,
imageFormat: ImagenImageFormat? = nil,
addWatermark: Bool? = nil) {
self.numberOfImages = numberOfImages
self.negativePrompt = negativePrompt
self.aspectRatio = aspectRatio
self.imageFormat = imageFormat
self.addWatermark = addWatermark
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenImageFormat {
let mimeType: String
let compressionQuality: Int?

public static func png() -> ImagenImageFormat {
return ImagenImageFormat(mimeType: "image/png", compressionQuality: nil)
}

public static func jpeg(compressionQuality: Int? = nil) -> ImagenImageFormat {
return ImagenImageFormat(mimeType: "image/jpeg", compressionQuality: compressionQuality)
}
}
34 changes: 23 additions & 11 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,24 @@ public final class ImagenModel {
self.requestOptions = requestOptions
}

public func generateImages(prompt: String) async throws
public func generateImages(prompt: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImageGenerationResponse<ImagenInlineDataImage> {
return try await generateImages(
prompt: prompt,
parameters: imageGenerationParameters(storageURI: nil)
parameters: imageGenerationParameters(storageURI: nil, generationConfig: generationConfig)
)
}

public func generateImages(prompt: String, storageURI: String) async throws
public func generateImages(prompt: String, storageURI: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImageGenerationResponse<ImagenFileDataImage> {
return try await generateImages(
prompt: prompt,
parameters: imageGenerationParameters(storageURI: storageURI)
parameters: imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig
)
)
}

Expand All @@ -74,18 +79,25 @@ public final class ImagenModel {
return try await generativeAIService.loadRequest(request: request)
}

func imageGenerationParameters(storageURI: String?) -> ImageGenerationParameters {
// TODO(#14221): Add support for configuring these parameters.
func imageGenerationParameters(storageURI: String?,
generationConfig: ImagenGenerationConfig? = nil)
-> ImageGenerationParameters {
// TODO(#14221): Add support for configuring remaining parameters.
return ImageGenerationParameters(
sampleCount: 1,
sampleCount: generationConfig?.numberOfImages ?? 1,
storageURI: storageURI,
seed: nil,
negativePrompt: nil,
aspectRatio: nil,
negativePrompt: generationConfig?.negativePrompt,
aspectRatio: generationConfig?.aspectRatio?.rawValue,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: nil,
addWatermark: nil,
outputOptions: generationConfig?.imageFormat.map {
ImageGenerationOutputOptions(
mimeType: $0.mimeType,
compressionQuality: $0.compressionQuality
)
},
addWatermark: generationConfig?.addWatermark,
includeResponsibleAIFilterReason: true
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,22 +249,28 @@ final class IntegrationTests: XCTestCase {
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
"""
var generationConfig = ImagenGenerationConfig()
generationConfig.aspectRatio = .landscape16x9
generationConfig.imageFormat = .jpeg(compressionQuality: 70)

let response = try await imagenModel.generateImages(prompt: imagePrompt)
let response = try await imagenModel.generateImages(
prompt: imagePrompt,
generationConfig: generationConfig
)

XCTAssertNil(response.raiFilteredReason)
XCTAssertEqual(response.images.count, 1)
let image = try XCTUnwrap(response.images.first)
XCTAssertEqual(image.mimeType, "image/png")
XCTAssertEqual(image.mimeType, "image/jpeg")
XCTAssertGreaterThan(image.data.count, 0)
let imagenImage = image.imagenImage
XCTAssertEqual(imagenImage.mimeType, image.mimeType)
XCTAssertEqual(imagenImage.bytesBase64Encoded, image.data.base64EncodedString())
XCTAssertNil(imagenImage.gcsURI)
#if canImport(UIKit)
let uiImage = try XCTUnwrap(UIImage(data: image.data))
XCTAssertEqual(uiImage.size.width, 1024.0)
XCTAssertEqual(uiImage.size.height, 1024.0)
XCTAssertEqual(uiImage.size.width, 1408.0)
XCTAssertEqual(uiImage.size.height, 768.0)
#endif
}
}
Expand Down

0 comments on commit c5472fc

Please sign in to comment.