assistant.php
2 years ago
base.php
2 years ago
embed.php
2 years ago
function.php
2 years ago
image.php
2 years ago
parameter.php
2 years ago
text.php
2 years ago
transcribe.php
2 years ago
assistant.php
293 lines
| 1 | <?php |
| 2 | |
/**
 * Query sent to an AI "assistant".
 *
 * Besides the usual query parameters (inherited from Meow_MWAI_Query_Base), it keeps a
 * list of chat-style messages whose last entry is always a 'user' message built from
 * the prompt (or from $newMessage when set), optionally combined with an image for
 * Vision-capable models.
 */
class Meow_MWAI_Query_Assistant extends Meow_MWAI_Query_Base implements JsonSerializable {
  // Chat messages, each an array with 'role' and 'content' keys. 'content' is either a
  // string, or a list of parts ('text' / 'image_url') when an image is attached.
  public array $messages = [];
  public ?string $newMessage = null;
  public ?string $newImage = null;
  public ?string $newImageData = null;
  // Either null or 'json' (see setResponseFormat()).
  public ?string $responseFormat = null;
  public ?int $promptTokens = null;
  public ?int $completionTokens = null;

  public ?string $chatId = null;
  public ?string $assistantId = null;
  public ?string $threadId = null;

  public function __construct( ?string $prompt = '' ) {
    parent::__construct( $prompt );
    $this->mode = "assistant";
  }

  #[\ReturnTypeWillChange]
  public function jsonSerialize() {
    return [
      'class' => get_class( $this ),
      'newMessage' => $this->newMessage,
      'newImage' => $this->newImage,
      'model' => $this->model,
      'session' => $this->session,
      'env' => $this->env,
      'envId' => $this->envId,
      'chatId' => $this->chatId,
      'assistantId' => $this->assistantId,
      'threadId' => $this->threadId,
      'service' => $this->service,
    ];
  }

  /**
   * Returns the estimated number of tokens used by the messages.
   * The estimation is cached; pass $refresh = true to recompute it.
   * @param bool $refresh Force a new estimation.
   * @return int The estimated number of prompt tokens.
   */
  public function getPromptTokens( $refresh = false ): int {
    // Compare against null so that a legitimate estimation of 0 is also cached.
    if ( $this->promptTokens !== null && !$refresh ) {
      return $this->promptTokens;
    }
    $this->promptTokens = Meow_MWAI_Core::estimateTokens( $this->messages );
    return $this->promptTokens;
  }

  /**
   * Returns the text of the last message, or the raw prompt if there are no messages.
   * Never returns null: an empty string is returned when no text is available
   * (e.g. the last message only contains an image).
   * @return string
   */
  public function getLastPrompt(): string {
    if ( empty( $this->messages ) ) {
      return $this->prompt ?? '';
    }
    return $this->getLastMessage() ?? '';
  }

  /**
   * Sets the prompt and rebuilds the trailing user message from it
   * (unless a newMessage is set, which takes precedence).
   * @param string $prompt The prompt.
   */
  public function setPrompt( $prompt ) {
    parent::setPrompt( $prompt );
    $this->validateMessages();
  }

  /**
   * The type of return expected from the API. It can be either null or "json".
   * Empty values are normalized to null.
   * @param string|null $responseFormat Either null (or empty) or 'json'.
   * @throws Exception If the format is neither empty nor 'json'.
   */
  public function setResponseFormat( $responseFormat ) {
    if ( !empty( $responseFormat ) && $responseFormat !== 'json' ) {
      throw new Exception( "AI Engine: The response format can only be null or json." );
    }
    $this->responseFormat = empty( $responseFormat ) ? null : $responseFormat;
  }

  /**
   * The prompt is used by models who uses Text Completion (and not Chat Completion).
   * This returns the prompt if it's not a chat, otherwise it will build a prompt with
   * all the messages nicely formatted ("User: ..." / "AI: ..." lines, system messages
   * as plain paragraphs), ending with "AI: " so the model continues as the AI.
   * @return string|null
   */
  public function getPrompt(): ?string {
    // In the case it's really just a prompt.
    if ( count( $this->messages ) === 1 ) {
      $only = reset( $this->messages );
      return $this->extractTextContent( $only['content'] ?? null );
    }

    // In the case it's a chat that we need to convert into a prompt.
    // Every message is visited (including the first one, whatever its role).
    $prompt = "";
    foreach ( $this->messages as $message ) {
      $role = $message['role'] ?? null;
      $content = $this->extractTextContent( $message['content'] ?? null );
      if ( $role === 'system' ) {
        $prompt .= "$content\n\n";
      }
      else if ( $role === 'user' ) {
        $prompt .= "User: $content\n";
      }
      else if ( $role === 'assistant' ) {
        $prompt .= "AI: $content\n";
      }
    }
    $prompt .= "AI: ";
    return $prompt;
  }

  /**
   * Similar to the prompt, but focus on the new/last message.
   * When set, it is used instead of the prompt for the trailing user message.
   * @param string $newMessage The new message sent by the user.
   */
  public function setNewMessage( string $newMessage ): void {
    $this->newMessage = $newMessage;
    $this->validateMessages();
  }

  /**
   * Attaches an image (by URL) to the trailing user message.
   * @param string $newImage The image URL.
   */
  public function setNewImage( string $newImage ): void {
    $this->newImage = $newImage;
    $this->validateMessages();
  }

  /**
   * Attaches an image (as base64 data) to the trailing user message.
   * Only used when no image URL is set.
   * @param string $newImageData The base64-encoded image data.
   */
  public function setNewImageData( string $newImageData ): void {
    $this->newImageData = $newImageData;
    $this->validateMessages();
  }

  public function setAssistantId( string $assistantId ): void {
    $this->assistantId = $assistantId;
  }

  public function setChatId( string $chatId ): void {
    $this->chatId = $chatId;
  }

  public function setThreadId( string $threadId ): void {
    $this->threadId = $threadId;
  }

  /**
   * Replaces $search by $replace in the prompt, then rebuilds the trailing user message.
   * NOTE(review): newMessage is not affected; when it is set it takes precedence over
   * the prompt, so the replacement will not be visible in the messages — confirm intent.
   * @param string|string[] $search
   * @param string|string[] $replace
   */
  public function replace( $search, $replace ) {
    $this->prompt = str_replace( $search, $replace, $this->prompt );
    $this->validateMessages();
  }

  /**
   * Accepts an array of messages, but is currently a no-op: the messages given here are
   * ignored and $this->messages keeps being built from the prompt / newMessage.
   * NOTE(review): presumably deliberate because the conversation history of an assistant
   * is kept in its thread (see $threadId) — confirm before re-enabling. Re-enabling would
   * mean normalizing each message (array or object) to [ 'role', 'content' ], storing
   * them, then calling validateMessages().
   * @param array $messages The messages (ignored).
   */
  public function setMessages( array $messages ) {
    return;
  }

  /**
   * Returns the text of the last message (the first 'text' part for Vision messages),
   * or null if there are no messages or the last one has no text.
   * @return string|null
   */
  public function getLastMessage() {
    if ( empty( $this->messages ) ) {
      return null;
    }
    // array_key_last() does not assume the messages are indexed 0..n-1.
    $lastMessage = $this->messages[ array_key_last( $this->messages ) ];
    return $this->extractTextContent( $lastMessage['content'] ?? null );
  }

  public function getMessages() {
    return $this->messages;
  }

  /**
   * Adds a system message holding $content just before the last message.
   * If there are no messages yet, the trailing user message is created first so the
   * context is not lost.
   * @param string $content The context to inject.
   */
  public function injectContext( string $content ): void {
    if ( empty( $this->messages ) ) {
      $this->validateMessages();
    }
    // Insert before the last element (offset -1, nothing removed). The replacement is
    // wrapped in an array because the inserted element is itself an array.
    array_splice( $this->messages, -1, 0, [ [ 'role' => 'system', 'content' => $content ] ] );
    $this->validateMessages();
  }

  /**
   * Returns the image to send to Vision, as a URL: the image URL if set, otherwise a
   * data URI built from the base64 image data, otherwise null.
   * NOTE(review): the data URI always declares image/jpeg, whatever the real format is.
   * @return string|null
   */
  private function getImageURL(): ?string {
    if ( !empty( $this->newImage ) ) {
      return $this->newImage;
    }
    if ( !empty( $this->newImageData ) ) {
      return "data:image/jpeg;base64,{$this->newImageData}";
    }
    return null;
  }

  /**
   * Extracts the text from a message content: a plain string is returned as-is; for a
   * list of parts (Vision), the first [ 'type' => 'text' ] part is returned.
   * @param mixed $content The 'content' entry of a message.
   * @return string|null The text, or null when there is none.
   */
  private function extractTextContent( $content ): ?string {
    if ( $content === null ) {
      return null;
    }
    if ( !is_array( $content ) ) {
      return (string)$content;
    }
    foreach ( $content as $part ) {
      if ( is_array( $part ) && ( $part['type'] ?? null ) === 'text' ) {
        return $part['text'] ?? null;
      }
    }
    return null;
  }

  /**
   * Makes sure the messages end with a user message built from the newMessage (if set)
   * or the prompt, combined with the image (if any). The last message is updated when it
   * is already a user message; otherwise a new user message is appended.
   */
  private function validateMessages(): void {
    // Messages should end with either the prompt or, if exists, the newMessage.
    $text = empty( $this->newMessage ) ? $this->prompt : $this->newMessage;
    $content = $text;

    // If there is an image, we need to adapt it to Vision.
    $imageURL = $this->getImageURL();
    if ( !empty( $imageURL ) ) {
      $content = [
        [ "type" => "text", "text" => $text ],
        [ "type" => "image_url", "image_url" => [ "url" => $imageURL ] ]
      ];
    }

    if ( empty( $this->messages ) ) {
      $this->messages = [ [ 'role' => 'user', 'content' => $content ] ];
      return;
    }

    $lastIndex = array_key_last( $this->messages );
    if ( ( $this->messages[$lastIndex]['role'] ?? null ) === 'user' ) {
      $this->messages[$lastIndex]['content'] = $content;
    }
    else {
      $this->messages[] = [ 'role' => 'user', 'content' => $content ];
    }
  }

  /**
   * Based on the params of the query, update the attributes.
   * Only non-empty params are applied, each through its setter.
   * @param array $params The params (keys are normalized by convertKeys()).
   */
  public function injectParams( array $params ): void
  {
    // Those are for the keys passed directly by the shortcode.
    $params = $this->convertKeys( $params );

    if ( !empty( $params['model'] ) ) {
      $this->setModel( $params['model'] );
    }
    if ( !empty( $params['prompt'] ) ) {
      $this->setPrompt( $params['prompt'] );
    }
    if ( !empty( $params['messages'] ) ) {
      $this->setMessages( $params['messages'] );
    }
    if ( !empty( $params['newMessage'] ) ) {
      $this->setNewMessage( $params['newMessage'] );
    }
    if ( !empty( $params['maxResults'] ) ) {
      $this->setMaxResults( $params['maxResults'] );
    }
    if ( !empty( $params['env'] ) ) {
      $this->setEnv( $params['env'] );
    }
    if ( !empty( $params['session'] ) ) {
      $this->setSession( $params['session'] );
    }
    // Should add the params related to Open AI and Azure
    if ( !empty( $params['service'] ) ) {
      $this->setService( $params['service'] );
    }
    if ( !empty( $params['apiKey'] ) ) {
      $this->setApiKey( $params['apiKey'] );
    }
    if ( !empty( $params['botId'] ) ) {
      $this->setBotId( $params['botId'] );
    }
    if ( !empty( $params['envId'] ) ) {
      $this->setEnvId( $params['envId'] );
    }
    if ( !empty( $params['chatId'] ) ) {
      $this->setChatId( $params['chatId'] );
    }
    if ( !empty( $params['assistantId'] ) ) {
      $this->setAssistantId( $params['assistantId'] );
    }
    if ( !empty( $params['threadId'] ) ) {
      $this->setThreadId( $params['threadId'] );
    }
    if ( !empty( $params['responseFormat'] ) ) {
      $this->setResponseFormat( $params['responseFormat'] );
    }
  }
}