Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added support for Claude 3+ Chat API in Bedrock #2870

Merged
merged 6 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions lib/instrumentation/aws-sdk/v3/bedrock.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,27 +118,34 @@ function recordChatCompletionMessages({
isError: err !== null
})

const msg = new LlmChatCompletionMessage({
agent,
segment,
bedrockCommand,
bedrockResponse,
transaction,
index: 0,
completionId: summary.id
// Record context message(s)
const promptContextMessages = bedrockCommand.prompt
promptContextMessages.forEach((contextMessage, promptIndex) => {
const msg = new LlmChatCompletionMessage({
agent,
segment,
transaction,
bedrockCommand,
content: contextMessage.content,
role: contextMessage.role,
bedrockResponse,
index: promptIndex,
completionId: summary.id
})
recordEvent({ agent, type: 'LlmChatCompletionMessage', msg })
})
recordEvent({ agent, type: 'LlmChatCompletionMessage', msg })

bedrockResponse.completions.forEach((content, index) => {
bedrockResponse.completions.forEach((content, completionIndex) => {
const chatCompletionMessage = new LlmChatCompletionMessage({
agent,
segment,
transaction,
bedrockCommand,
bedrockResponse,
isResponse: true,
index: index + 1,
index: promptContextMessages.length + completionIndex,
content,
role: 'assistant',
completionId: summary.id
})
recordEvent({ agent, type: 'LlmChatCompletionMessage', msg: chatCompletionMessage })
Expand Down Expand Up @@ -179,18 +186,22 @@ function recordEmbeddingMessage({
return
}

const embedding = new LlmEmbedding({
const embeddings = bedrockCommand.prompt.map(prompt => new LlmEmbedding({
agent,
segment,
transaction,
bedrockCommand,
input: prompt.content,
bedrockResponse,
isError: err !== null
}))

embeddings.forEach(embedding => {
recordEvent({ agent, type: 'LlmEmbedding', msg: embedding })
Comment on lines +199 to +200
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to do a bit more research on this but I think some of the Bedrock embedding models allow you to make a single invoke call that generates several embeddings (see this Cohere blog for an example). So I think it might be correct to allow unrolling one embedding command to several embedding events. Currently all the embedding models in the command class produce a single prompt but I'm wondering if the Cohere one is also incorrectly squashing messages in some cases.

I might try to treat that as a separate PR / task if you're alright with that though

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that's the case, how do you want to handle an error? I assume we still want one error attached to the transaction? I opted to only attach the embedding info if there's one event to keep the current behavior but I'm not sure that's correct

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know this enough, but seems ok for now

})

recordEvent({ agent, type: 'LlmEmbedding', msg: embedding })
if (err) {
const llmError = new LlmError({ bedrockResponse, err, embedding })
const llmError = new LlmError({ bedrockResponse, err, embedding: embeddings.length === 1 ? embeddings[0] : undefined })
agent.errors.add(transaction, err, llmError)
}
}
Expand Down
76 changes: 55 additions & 21 deletions lib/llm-events/aws-bedrock/bedrock-command.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

'use strict'

const { stringifyClaudeChunkedMessage } = require('./utils')

/**
* Parses an AWS invoke command instance into a re-usable entity.
*/
Expand Down Expand Up @@ -68,37 +70,34 @@
/**
* The question posed to the LLM.
*
* @returns {string|string[]|undefined}
* @returns {object[]} The array of context messages passed to the LLM (or a single user prompt for legacy "non-chat" models)
*/
get prompt() {
let result
if (this.isTitan() === true || this.isTitanEmbed() === true) {
result = this.#body.inputText
return [
{
role: 'user',
content: this.#body.inputText
}
]
} else if (this.isCohereEmbed() === true) {
result = this.#body.texts.join(' ')
return [
{
role: 'user',
content: this.#body.texts.join(' ')
}
]
} else if (
this.isClaude() === true ||
this.isClaudeTextCompletionApi() === true ||
this.isAi21() === true ||
this.isCohere() === true ||
this.isLlama() === true
) {
result = this.#body.prompt
} else if (this.isClaude3() === true) {
const collected = []
for (const message of this.#body?.messages) {
if (message?.role === 'assistant') {
continue
}
if (typeof message?.content === 'string') {
collected.push(message?.content)
continue
}
const mappedMsgObj = message?.content.map((msgContent) => msgContent.text)
collected.push(mappedMsgObj)
}
result = collected.join(' ')
return [{ role: 'user', content: this.#body.prompt }]
} else if (this.isClaudeMessagesApi() === true) {
return normalizeClaude3Messages(this.#body?.messages)
}
return result
return []
}

/**
Expand Down Expand Up @@ -151,6 +150,41 @@
isTitanEmbed() {
return this.#modelId.startsWith('amazon.titan-embed')
}

/**
 * Whether this command targets the Claude "Messages" (chat) API, i.e. the
 * request body carries a `messages` array. Matches both Claude 3+ models and
 * older Claude models invoked through the messages-style body.
 *
 * @returns {boolean}
 */
isClaudeMessagesApi() {
return (this.isClaude3() === true || this.isClaude() === true) && 'messages' in this.#body
}

/**
 * Whether this command targets the legacy Claude "Text Completions" API,
 * i.e. the request body carries a plain `prompt` string.
 *
 * @returns {boolean}
 */
isClaudeTextCompletionApi() {
return this.isClaude() === true && 'prompt' in this.#body
}
}

/**
 * Claude v3 requests in Bedrock can have two different "chat" flavors. This
 * function normalizes them into a consistent format per the AIM agent spec.
 *
 * @param {object[]} [messages] - The raw array of messages passed to the
 * invoke API. May be null/undefined, in which case an empty array is returned.
 * Null/undefined entries within the array are skipped.
 * @returns {object[]} The normalized messages: `{ role, content }` pairs in
 * which `content` is always a plain string (chunked multi-modal content is
 * flattened via `stringifyClaudeChunkedMessage`).
 */
function normalizeClaude3Messages(messages) {
  const result = []
  for (const message of messages ?? []) {
    if (message == null) {
      continue
    }
    if (typeof message.content === 'string') {
      // Messages can be specified with plain string content
      result.push({ role: message.role, content: message.content })
    } else if (Array.isArray(message.content)) {
      // Or in a "chunked" format for multi-modal support
      result.push({
        role: message.role,
        content: stringifyClaudeChunkedMessage(message.content)
      })
    }
    // Any other content shape is silently dropped; nothing useful to record.
  }
  return result
}

module.exports = BedrockCommand
4 changes: 3 additions & 1 deletion lib/llm-events/aws-bedrock/bedrock-response.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

'use strict'

const { stringifyClaudeChunkedMessage } = require('./utils')

/**
* @typedef {object} AwsBedrockMiddlewareResponse
* @property {object} response Has a `body` property that is an IncomingMessage,
Expand Down Expand Up @@ -63,7 +65,7 @@ class BedrockResponse {
// Streamed response
this.#completions = body.completions
} else {
this.#completions = body?.content?.map((c) => c.text)
this.#completions = [stringifyClaudeChunkedMessage(body?.content)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see a versioned test but not a unit test for this

}
this.#id = body.id
} else if (cmd.isCohere() === true) {
Expand Down
17 changes: 4 additions & 13 deletions lib/llm-events/aws-bedrock/chat-completion-message.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,19 @@ class LlmChatCompletionMessage extends LlmEvent {
params = Object.assign({}, defaultParams, params)
super(params)

const { agent, content, isResponse, index, completionId } = params
const { agent, content, isResponse, index, completionId, role } = params
const recordContent = agent.config?.ai_monitoring?.record_content?.enabled
const tokenCB = agent?.llm?.tokenCountCallback

this.is_response = isResponse
this.completion_id = completionId
this.sequence = index
this.content = recordContent === true ? content : undefined
this.role = ''
this.role = role

this.#setId(index)
if (this.is_response === true) {
this.role = 'assistant'
if (typeof tokenCB === 'function') {
this.token_count = tokenCB(this.bedrockCommand.modelId, content)
}
} else {
this.role = 'user'
this.content = recordContent === true ? this.bedrockCommand.prompt : undefined
if (typeof tokenCB === 'function') {
this.token_count = tokenCB(this.bedrockCommand.modelId, this.bedrockCommand.prompt)
}
if (typeof tokenCB === 'function') {
this.token_count = tokenCB(this.bedrockCommand.modelId, content)
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/llm-events/aws-bedrock/chat-completion-summary.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class LlmChatCompletionSummary extends LlmEvent {
const cmd = this.bedrockCommand
this[cfr] = this.bedrockResponse.finishReason
this[rt] = cmd.temperature
this[nm] = 1 + this.bedrockResponse.completions.length
this[nm] = (this.bedrockCommand.prompt.length) + this.bedrockResponse.completions.length
}
}

Expand Down
10 changes: 6 additions & 4 deletions lib/llm-events/aws-bedrock/embedding.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const LlmEvent = require('./event')
/**
* @typedef {object} LlmEmbeddingParams
* @augments LlmEventParams
* @property
* @property {string} input - The input message for the embedding call
*/
/**
* @type {LlmEmbeddingParams}
Expand All @@ -20,16 +20,18 @@ const defaultParams = {}
class LlmEmbedding extends LlmEvent {
constructor(params = defaultParams) {
super(params)
const { agent } = params
const { agent, input } = params
const tokenCb = agent?.llm?.tokenCountCallback

this.input = agent.config?.ai_monitoring?.record_content?.enabled
? this.bedrockCommand.prompt
? input
: undefined
this.error = params.isError
this.duration = params.segment.getDurationInMillis()

// Even if not recording content, we should use the local token counting callback to record token usage
if (typeof tokenCb === 'function') {
this.token_count = tokenCb(this.bedrockCommand.modelId, this.bedrockCommand.prompt)
this.token_count = tokenCb(this.bedrockCommand.modelId, input)
}
}
}
Expand Down
36 changes: 36 additions & 0 deletions lib/llm-events/aws-bedrock/utils.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2024 New Relic Corporation. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

'use strict'

/**
 * Flattens a Claude "chunked" message into a single string.
 *
 * @param {object[]} [chunks] - The "chunks" that make up a single conceptual
 * message. In a multi-modal scenario, a single message might have a number of
 * different-typed chunks interspersed. May be null/undefined (callers pass
 * `body?.content`, which can be absent), in which case an empty string is
 * returned.
 * @returns {string} - A stringified version of the message. We make a
 * best-effort attempt to represent non-text chunks. In the future we may want
 * to extend the agent to support these non-text chunks in a richer way.
 * Placeholders are represented in an XML-like format but are NOT intended to
 * be parsed as valid XML.
 */
function stringifyClaudeChunkedMessage(chunks) {
  // Guard against a missing content array so a malformed/partial response
  // does not throw inside instrumentation.
  const stringifiedChunks = (chunks ?? []).map((msgContent) => {
    switch (msgContent.type) {
      case 'text':
        return msgContent.text
      case 'image':
        return '<image>'
      case 'tool_use':
        return `<tool_use>${msgContent.name}</tool_use>`
      case 'tool_result':
        return `<tool_result>${msgContent.content}</tool_result>`
      default:
        return '<unknown_chunk>'
    }
  })
  return stringifiedChunks.join('\n\n')
}

module.exports = {
stringifyClaudeChunkedMessage
}
6 changes: 6 additions & 0 deletions test/lib/aws-server-stubs/ai-server/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ function handler(req, res) {
break
}

// Chunked claude model
case 'anthropic.claude-3-5-sonnet-20240620-v1:0': {
response = responses.claude3.get(payload?.messages?.[0]?.content?.[0].text)
break
}

case 'cohere.command-text-v14':
case 'cohere.command-light-text-v14': {
response = responses.cohere.get(payload.prompt)
Expand Down
34 changes: 34 additions & 0 deletions test/lib/aws-server-stubs/ai-server/responses/claude3.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,40 @@ responses.set('text claude3 ultimate question', {
}
})

// Multi-modal ("chunked") Claude 3 response fixture: the assistant's content
// is an array of typed chunks (a text chunk followed by an image chunk)
// rather than a single string.
responses.set('text claude3 ultimate question chunked', {
  headers: {
    'content-type': contentType,
    'x-amzn-requestid': reqId,
    'x-amzn-bedrock-invocation-latency': '926',
    'x-amzn-bedrock-output-token-count': '36',
    'x-amzn-bedrock-input-token-count': '14'
  },
  statusCode: 200,
  body: {
    id: 'msg_bdrk_019V7ABaw8ZZZYuRDSTWK7VE',
    type: 'message',
    role: 'assistant',
    model: 'claude-3-haiku-20240307',
    stop_sequence: null,
    usage: { input_tokens: 30, output_tokens: 265 },
    content: [
      {
        type: 'text',
        text: "Here's a nice picture of a 42"
      },
      {
        // base64 payload is a stub placeholder, not real image data
        type: 'image',
        source: {
          type: 'base64',
          media_type: 'image/jpeg',
          data: 'U2hoLiBUaGlzIGlzbid0IHJlYWxseSBhbiBpbWFnZQ=='
        }
      }
    ],
    stop_reason: 'endoftext'
  }
})

responses.set('text claude3 ultimate question streamed', {
headers: {
'content-type': 'application/vnd.amazon.eventstream',
Expand Down
Loading
Loading