Skip to content

Commit 8a0c9bb

Browse files
committed
Encode model in GenerateContentRequest only when needed
1 parent eae0721 commit 8a0c9bb

File tree

3 files changed

+90
-20
lines changed

3 files changed

+90
-20
lines changed

Diff for: Sources/GoogleAI/GenerateContentRequest.swift

+28-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ import Foundation
1616

1717
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
1818
struct GenerateContentRequest {
19-
/// Model name.
19+
// Model name.
2020
let model: String
21+
// If true, the `model` field above is encoded in requests; currently only required when nested in
22+
// a `CountTokensRequest`.
23+
let isModelEncoded: Bool
2124
let contents: [ModelContent]
2225
let generationConfig: GenerationConfig?
2326
let safetySettings: [SafetySetting]?
@@ -39,6 +42,30 @@ extension GenerateContentRequest: Encodable {
3942
case toolConfig
4043
case systemInstruction
4144
}
45+
46+
func encode(to encoder: any Encoder) throws {
47+
var container = encoder.container(keyedBy: CodingKeys.self)
48+
49+
if isModelEncoded {
50+
try container.encode(model, forKey: .model)
51+
}
52+
try container.encode(contents, forKey: .contents)
53+
if let generationConfig {
54+
try container.encode(generationConfig, forKey: .generationConfig)
55+
}
56+
if let safetySettings {
57+
try container.encode(safetySettings, forKey: .safetySettings)
58+
}
59+
if let tools {
60+
try container.encode(tools, forKey: .tools)
61+
}
62+
if let toolConfig {
63+
try container.encode(toolConfig, forKey: .toolConfig)
64+
}
65+
if let systemInstruction {
66+
try container.encode(systemInstruction, forKey: .systemInstruction)
67+
}
68+
}
4269
}
4370

4471
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)

Diff for: Sources/GoogleAI/GenerativeModel.swift

+25-18
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,18 @@ public final class GenerativeModel {
175175
-> GenerateContentResponse {
176176
let response: GenerateContentResponse
177177
do {
178-
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
179-
contents: content(),
180-
generationConfig: generationConfig,
181-
safetySettings: safetySettings,
182-
tools: tools,
183-
toolConfig: toolConfig,
184-
systemInstruction: systemInstruction,
185-
isStreaming: false,
186-
options: requestOptions)
178+
let generateContentRequest = try GenerateContentRequest(
179+
model: modelResourceName,
180+
isModelEncoded: false,
181+
contents: content(),
182+
generationConfig: generationConfig,
183+
safetySettings: safetySettings,
184+
tools: tools,
185+
toolConfig: toolConfig,
186+
systemInstruction: systemInstruction,
187+
isStreaming: false,
188+
options: requestOptions
189+
)
187190
response = try await generativeAIService.loadRequest(request: generateContentRequest)
188191
} catch {
189192
if let imageError = error as? ImageConversionError {
@@ -249,15 +252,18 @@ public final class GenerativeModel {
249252
}
250253
}
251254

252-
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
253-
contents: evaluatedContent,
254-
generationConfig: generationConfig,
255-
safetySettings: safetySettings,
256-
tools: tools,
257-
toolConfig: toolConfig,
258-
systemInstruction: systemInstruction,
259-
isStreaming: true,
260-
options: requestOptions)
255+
let generateContentRequest = GenerateContentRequest(
256+
model: modelResourceName,
257+
isModelEncoded: false,
258+
contents: evaluatedContent,
259+
generationConfig: generationConfig,
260+
safetySettings: safetySettings,
261+
tools: tools,
262+
toolConfig: toolConfig,
263+
systemInstruction: systemInstruction,
264+
isStreaming: true,
265+
options: requestOptions
266+
)
261267

262268
var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
263269
.makeAsyncIterator()
@@ -326,6 +332,7 @@ public final class GenerativeModel {
326332
-> CountTokensResponse {
327333
do {
328334
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
335+
isModelEncoded: true,
329336
contents: content(),
330337
generationConfig: generationConfig,
331338
safetySettings: safetySettings,

Diff for: Tests/GoogleAITests/GenerateContentRequestTests.swift

+37-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ final class GenerateContentRequestTests: XCTestCase {
3636
let content = [ModelContent(role: role, parts: prompt)]
3737
let request = GenerateContentRequest(
3838
model: modelName,
39+
isModelEncoded: true,
3940
contents: content,
4041
generationConfig: GenerationConfig(temperature: 0.5),
4142
safetySettings: [SafetySetting(
@@ -108,10 +109,11 @@ final class GenerateContentRequestTests: XCTestCase {
108109
""")
109110
}
110111

111-
func testEncodeRequest_optionalFieldsOmitted() throws {
112+
func testEncodeRequest_optionalFieldsOmitted_modelNameEncoded() throws {
112113
let content = [ModelContent(role: role, parts: prompt)]
113114
let request = GenerateContentRequest(
114115
model: modelName,
116+
isModelEncoded: true,
115117
contents: content,
116118
generationConfig: nil,
117119
safetySettings: nil,
@@ -141,4 +143,38 @@ final class GenerateContentRequestTests: XCTestCase {
141143
}
142144
""")
143145
}
146+
147+
func testEncodeRequest_optionalFieldsOmitted_modelNameNotEncoded() throws {
148+
let content = [ModelContent(role: role, parts: prompt)]
149+
let request = GenerateContentRequest(
150+
model: modelName,
151+
isModelEncoded: false,
152+
contents: content,
153+
generationConfig: nil,
154+
safetySettings: nil,
155+
tools: nil,
156+
toolConfig: nil,
157+
systemInstruction: nil,
158+
isStreaming: false,
159+
options: RequestOptions()
160+
)
161+
162+
let jsonData = try encoder.encode(request)
163+
164+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
165+
XCTAssertEqual(json, """
166+
{
167+
"contents" : [
168+
{
169+
"parts" : [
170+
{
171+
"text" : "\(prompt)"
172+
}
173+
],
174+
"role" : "\(role)"
175+
}
176+
]
177+
}
178+
""")
179+
}
144180
}

0 commit comments

Comments
 (0)