[Vertex AI] Use struct instead of enum for HarmCategory #13728

Merged · 5 commits · Oct 8, 2024
3 changes: 3 additions & 0 deletions FirebaseVertexAI/CHANGELOG.md
@@ -36,6 +36,9 @@
   as input. (#13767)
 - [changed] **Breaking Change**: All initializers for `ModelContent` now require
   the label `parts: `. (#13832)
+- [changed] **Breaking Change**: `HarmCategory` is now a struct instead of an
+  enum type and the `unknown` case has been removed; in a `switch` statement,
+  use the `default:` case to cover unknown or unhandled categories. (#13728)
 - [changed] The default request timeout is now 180 seconds instead of the
   platform-default value of 60 seconds for a `URLRequest`; this timeout may
   still be customized in `RequestOptions`. (#13722)
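
For client code, the practical effect of this entry is small: `HarmCategory` remains usable in a `switch` because the static properties act as expression patterns for an `Equatable` type; only the removed `unknown` case needs to become `default:`. A minimal migration sketch, where `rating` is a hypothetical `SafetyRating` pulled from a response:

```swift
switch rating.category {
case .harassment:
  print("Flagged for harassment")
case .hateSpeech:
  print("Flagged for hate speech")
default:
  // Covers categories this SDK version doesn't recognize, including
  // values the backend may introduce after release.
  print("Other category: \(rating.category.rawValue)")
}
```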
@@ -23,7 +23,9 @@ extension HarmCategory: CustomStringConvertible {
     case .harassment: "Harassment"
     case .hateSpeech: "Hate speech"
     case .sexuallyExplicit: "Sexually explicit"
-    case .unknown: "Unknown"
+    case .civicIntegrity: "Civic integrity"
+    default:
+      "Unknown HarmCategory: \(rawValue)"
     }
   }
 }
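
With the `default:` arm above, `description` now surfaces the raw server string rather than a bare "Unknown". A rough illustration; note that `init(rawValue:)` is internal in this diff, so constructing an arbitrary value like this assumes module-internal (or `@testable`) access:

```swift
print(HarmCategory.hateSpeech)  // Prints "Hate speech"

// Hypothetical future value not known to this SDK version.
let future = HarmCategory(rawValue: "HARM_CATEGORY_SOMETHING_NEW")
print(future)  // Prints "Unknown HarmCategory: HARM_CATEGORY_SOMETHING_NEW"
```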
71 changes: 53 additions & 18 deletions FirebaseVertexAI/Sources/Safety.swift
@@ -97,21 +97,65 @@ public struct SafetySetting {
 }
 
 /// Categories describing the potential harm a piece of content may pose.
-public enum HarmCategory: String, Sendable {
-  /// Unknown. A new server value that isn't recognized by the SDK.
-  case unknown = "HARM_CATEGORY_UNKNOWN"
+public struct HarmCategory: Sendable, Equatable, Hashable {
+  enum Kind: String {
+    case harassment = "HARM_CATEGORY_HARASSMENT"
+    case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
+    case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
+    case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
+    case civicIntegrity = "HARM_CATEGORY_CIVIC_INTEGRITY"
+  }
 
   /// Harassment content.
-  case harassment = "HARM_CATEGORY_HARASSMENT"
+  public static var harassment: HarmCategory {
+    return self.init(kind: .harassment)
+  }
 
   /// Negative or harmful comments targeting identity and/or protected attributes.
-  case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
+  public static var hateSpeech: HarmCategory {
+    return self.init(kind: .hateSpeech)
+  }
 
   /// Contains references to sexual acts or other lewd content.
-  case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
+  public static var sexuallyExplicit: HarmCategory {
+    return self.init(kind: .sexuallyExplicit)
+  }
 
   /// Promotes or enables access to harmful goods, services, or activities.
-  case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
+  public static var dangerousContent: HarmCategory {
+    return self.init(kind: .dangerousContent)
+  }
+
+  /// Content that may be used to harm civic integrity.
+  public static var civicIntegrity: HarmCategory {
+    return self.init(kind: .civicIntegrity)
+  }
+
+  /// Returns the raw string representation of the `HarmCategory` value.
+  ///
+  /// > Note: This value directly corresponds to the values in the
+  /// > [REST API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/HarmCategory).
+  public let rawValue: String
+
+  init(kind: Kind) {
+    rawValue = kind.rawValue
+  }
+
+  init(rawValue: String) {
+    if Kind(rawValue: rawValue) == nil {
+      VertexLog.error(
+        code: .generateContentResponseUnrecognizedHarmCategory,
+        """
+        Unrecognized HarmCategory with value "\(rawValue)":
+        - Check for updates to the SDK as support for "\(rawValue)" may have been added; see \
+        release notes at https://firebase.google.com/support/release-notes/ios
+        - Search for "\(rawValue)" in the Firebase Apple SDK Issue Tracker at \
+        https://github.com/firebase/firebase-ios-sdk/issues and file a Bug Report if none found
+        """
+      )
+    }
+    self.rawValue = rawValue
+  }
 }
 
 // MARK: - Codable Conformances
@@ -139,17 +183,8 @@ extension SafetyRating: Decodable {}
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension HarmCategory: Codable {
   public init(from decoder: Decoder) throws {
-    let value = try decoder.singleValueContainer().decode(String.self)
-    guard let decodedCategory = HarmCategory(rawValue: value) else {
-      VertexLog.error(
-        code: .generateContentResponseUnrecognizedHarmCategory,
-        "Unrecognized HarmCategory with value \"\(value)\"."
-      )
-      self = .unknown
-      return
-    }
-
-    self = decodedCategory
+    let rawValue = try decoder.singleValueContainer().decode(String.self)
+    self = HarmCategory(rawValue: rawValue)
   }
 }

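Because `Equatable` and `Hashable` are synthesized from the single `rawValue` stored property, a value built from a raw server string compares equal to the corresponding static property, and unrecognized strings now survive decoding instead of collapsing into a lossy `unknown` case. A sketch under the same module-internal access assumption as above; the `HARM_CATEGORY_BRAND_NEW` string is invented for illustration:

```swift
import Foundation

// A known raw value equals its static counterpart.
assert(HarmCategory(rawValue: "HARM_CATEGORY_HARASSMENT") == .harassment)

// An unrecognized value decodes successfully (an error is logged via
// VertexLog) and preserves the server's string verbatim.
let json = Data(#"["HARM_CATEGORY_BRAND_NEW"]"#.utf8)
let categories = try JSONDecoder().decode([HarmCategory].self, from: json)
assert(categories.first?.rawValue == "HARM_CATEGORY_BRAND_NEW")
```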
14 changes: 9 additions & 5 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
@@ -163,7 +163,7 @@ final class GenerativeModelTests: XCTestCase {
     let expectedSafetyRatings = [
       SafetyRating(category: .harassment, probability: .medium),
       SafetyRating(category: .dangerousContent, probability: .unknown),
-      SafetyRating(category: .unknown, probability: .high),
+      SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
     ]
     MockURLProtocol
       .requestHandler = try httpRequestHandler(
@@ -972,18 +972,22 @@
       forResource: "streaming-success-unknown-safety-enum",
       withExtension: "txt"
     )
+    let unknownSafetyRating = SafetyRating(
+      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
+      probability: .unknown
+    )
 
-    var hadUnknown = false
+    var foundUnknownSafetyRating = false
     let stream = try model.generateContentStream("Hi")
     for try await content in stream {
       XCTAssertNotNil(content.text)
       if let ratings = content.candidates.first?.safetyRatings,
-         ratings.contains(where: { $0.category == .unknown }) {
-        hadUnknown = true
+         ratings.contains(where: { $0 == unknownSafetyRating }) {
+        foundUnknownSafetyRating = true
       }
     }
 
-    XCTAssertTrue(hadUnknown)
+    XCTAssertTrue(foundUnknownSafetyRating)
   }
 
   func testGenerateContentStream_successWithCitations() async throws {
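Call sites that reference categories by name stay source-compatible with the enum-to-struct change, since the static properties mirror the old case names. For example, configuring safety settings should read the same as before; this sketch assumes the existing `SafetySetting(harmCategory:threshold:)` initializer and `HarmBlockThreshold` values from the SDK's public API, which are not part of this diff:

```swift
let safetySettings = [
  SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh),
  SafetySetting(harmCategory: .civicIntegrity, threshold: .blockMediumAndAbove),
]
```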