Skip to content

Commit

Permalink
code snippet provider
Browse files Browse the repository at this point in the history
  • Loading branch information
lukka committed Dec 23, 2024
1 parent 0aaae1f commit d1876a2
Show file tree
Hide file tree
Showing 8 changed files with 1,101 additions and 725 deletions.
1 change: 1 addition & 0 deletions Extension/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6573,6 +6573,7 @@
"xml2js": "^0.6.2"
},
"dependencies": {
"@github/copilot-language-server": "^1.253.0",
"@vscode/extension-telemetry": "^0.9.6",
"chokidar": "^3.6.0",
"comment-json": "^4.2.3",
Expand Down
27 changes: 26 additions & 1 deletion Extension/src/LanguageServer/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ import {
} from './codeAnalysis';
import { Location, TextEdit, WorkspaceEdit } from './commonTypes';
import * as configs from './configurations';
import { CopilotCompletionContextProvider } from './copilotCompletionContextProvider';
import { DataBinding } from './dataBinding';
import { cachedEditorConfigSettings, getEditorConfigSettings } from './editorConfig';
import { CppSourceStr, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension';
import { CppSourceStr, SnippetEntry, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension';
import { LocalizeStringParams, getLocaleId, getLocalizedString } from './localization';
import { PersistentFolderState, PersistentState, PersistentWorkspaceState } from './persistentState';
import { RequestCancelled, ServerCancelled, createProtocolFilter } from './protocolFilter';
Expand Down Expand Up @@ -575,6 +576,16 @@ interface FilesEncodingChanged {
foldersFilesEncoding: FolderFilesEncodingChanged[];
}

export interface CompletionContextResult {
snippets: SnippetEntry[];
translationUnitUri: string;
}

export interface CompletionContextParams {
file: string;
caretOffset: number;
}

// Requests
const PreInitializationRequest: RequestType<void, string, void> = new RequestType<void, string, void>('cpptools/preinitialize');
const InitializationRequest: RequestType<CppInitializationParams, void, void> = new RequestType<CppInitializationParams, void, void>('cpptools/initialize');
Expand All @@ -597,6 +608,7 @@ const ChangeCppPropertiesRequest: RequestType<CppPropertiesParams, void, void> =
const IncludesRequest: RequestType<GetIncludesParams, GetIncludesResult, void> = new RequestType<GetIncludesParams, GetIncludesResult, void>('cpptools/getIncludes');
const CppContextRequest: RequestType<TextDocumentIdentifier, ChatContextResult, void> = new RequestType<TextDocumentIdentifier, ChatContextResult, void>('cpptools/getChatContext');
const ProjectContextRequest: RequestType<TextDocumentIdentifier, ProjectContextResult, void> = new RequestType<TextDocumentIdentifier, ProjectContextResult, void>('cpptools/getProjectContext');
const CompletionContextRequest: RequestType<CompletionContextParams, CompletionContextResult, void> = new RequestType<CompletionContextParams, CompletionContextResult, void>('cpptools/getCompletionContext');

// Notifications to the server
const DidOpenNotification: NotificationType<DidOpenTextDocumentParams> = new NotificationType<DidOpenTextDocumentParams>('textDocument/didOpen');
Expand Down Expand Up @@ -832,6 +844,7 @@ export interface Client {
getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ChatContextResult>;
getProjectContext(uri: vscode.Uri): Promise<ProjectContextResult>;
filesEncodingChanged(filesEncodingChanged: FilesEncodingChanged): void;
getCompletionContext(fileName: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextResult>;
}

export function createClient(workspaceFolder?: vscode.WorkspaceFolder): Client {
Expand Down Expand Up @@ -866,6 +879,7 @@ export class DefaultClient implements Client {
private configurationProvider?: string;
private hoverProvider: HoverProvider | undefined;
private copilotHoverProvider: CopilotHoverProvider | undefined;
private copilotCompletionProvider?: CopilotCompletionContextProvider;

public lastCustomBrowseConfiguration: PersistentFolderState<WorkspaceBrowseConfiguration | undefined> | undefined;
public lastCustomBrowseConfigurationProviderId: PersistentFolderState<string | undefined> | undefined;
Expand Down Expand Up @@ -1333,6 +1347,9 @@ export class DefaultClient implements Client {
this.semanticTokensProviderDisposable = vscode.languages.registerDocumentSemanticTokensProvider(util.documentSelector, this.semanticTokensProvider, semanticTokensLegend);
}

this.copilotCompletionProvider = await CopilotCompletionContextProvider.Create();
this.disposables.push(this.copilotCompletionProvider);

// Listen for messages from the language server.
this.registerNotifications();

Expand Down Expand Up @@ -1864,6 +1881,7 @@ export class DefaultClient implements Client {
if (diagnosticsCollectionIntelliSense) {
diagnosticsCollectionIntelliSense.delete(document.uri);
}
this.copilotCompletionProvider?.removeFile(uri);
openFileVersions.delete(uri);
}

Expand Down Expand Up @@ -2312,6 +2330,12 @@ export class DefaultClient implements Client {
() => this.languageClient.sendRequest(CppContextRequest, params, token), token);
}

public async getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextResult> {
await withCancellation(this.ready, token);
return DefaultClient.withLspCancellationHandling(
() => this.languageClient.sendRequest(CompletionContextRequest, { file: file.toString(), caretOffset }, token), token);
}

/**
* a Promise that can be awaited to know when it's ok to proceed.
*
Expand Down Expand Up @@ -4240,4 +4264,5 @@ class NullClient implements Client {
getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ChatContextResult> { return Promise.resolve({} as ChatContextResult); }
getProjectContext(uri: vscode.Uri): Promise<ProjectContextResult> { return Promise.resolve({} as ProjectContextResult); }
filesEncodingChanged(filesEncodingChanged: FilesEncodingChanged): void { }
getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextResult> { return Promise.resolve({} as CompletionContextResult); }
}
229 changes: 229 additions & 0 deletions Extension/src/LanguageServer/copilotCompletionContextProvider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
/* --------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All Rights Reserved.
* See 'LICENSE' in the project root for license information.
* ------------------------------------------------------------------------------------------ */
import { CodeSnippet, ContextResolver, ResolveRequest } from '@github/copilot-language-server';
import * as vscode from 'vscode';
import { DocumentSelector } from 'vscode-languageserver-protocol';
import { getOutputChannelLogger, Logger } from '../logger';
import * as telemetry from '../telemetry';
import { CompletionContextResult } from './client';
import { CopilotCompletionContextTelemetry } from './copilotCompletionContextTelemetry';
import { getCopilotApi } from './copilotProviders';
import { clients } from './extension';

class DefaultValueFallback extends Error {
static readonly DefaultValue = "DefaultValue";
constructor() { super(DefaultValueFallback.DefaultValue); }
}

class CancellationError extends Error {
static readonly Canceled = "Canceled";
constructor() { super(CancellationError.Canceled); }
}

class CopilotContextProviderException extends Error {
}

class WellKnownErrors extends Error {
static readonly ClientNotFound = "ClientNotFound";
private constructor(message: string) { super(message); }
public static clientNotFound(): Error {
return new WellKnownErrors(WellKnownErrors.ClientNotFound);
}
}

// Mutually exclusive values for the kind of returned completion context. They either are:
// - computed.
// - obtained from the cache.
// - missing since the computation took too long and no cache is present (cache miss). The value
// is asynchronously computed and stored in cache.
// - the token is signaled as cancelled, in which case all the operations are aborted.
// - an unknown state.
enum CopilotCompletionKind {
Computed = 'computed',
GotFromCache = 'gotFromCacheHit',
MissingCacheMiss = 'missingCacheMiss',
Canceled = 'canceled',
Unknown = 'unknown'
}

export class CopilotCompletionContextProvider implements ContextResolver<CodeSnippet> {
private static readonly providerId = 'cppTools';
private readonly completionContextCache: Map<string, CompletionContextResult> = new Map<string, CompletionContextResult>();
private static readonly defaultCppDocumentSelector: DocumentSelector = [{ language: 'cpp' }, { language: 'c' }, { language: 'cuda-cpp' }];
private static readonly defaultTimeBudgetFactor: number = 0.5;
private completionContextCancellation = new vscode.CancellationTokenSource();
private contextProviderDisposable: vscode.Disposable | undefined;

private async waitForCompletionWithTimeoutAndCancellation<T>(promise: Promise<T>, defaultValue: T | undefined,
timeout: number, token: vscode.CancellationToken): Promise<[T | undefined, CopilotCompletionKind]> {
const defaultValuePromise = new Promise<T>((_resolve, reject) => setTimeout(() => {
if (token.isCancellationRequested) {
reject(new CancellationError());
} else {
reject(new DefaultValueFallback());
}
}, timeout));
const cancellationPromise = new Promise<T>((_, reject) => {
token.onCancellationRequested(() => {
reject(new CancellationError());
});
});
let snippetsOrNothing: T | undefined;
try {
snippetsOrNothing = await Promise.race([promise, cancellationPromise, defaultValuePromise]);
} catch (e) {
if (e instanceof DefaultValueFallback) {
return [defaultValue, defaultValue !== undefined ? CopilotCompletionKind.GotFromCache : CopilotCompletionKind.MissingCacheMiss];
} else if (e instanceof CancellationError) {
return [undefined, CopilotCompletionKind.Canceled];
} else {
throw e;
}
}

return [snippetsOrNothing, CopilotCompletionKind.Computed];
}

// Get the completion context with a timeout and a cancellation token.
// The cancellationToken indicates that the value should not be returned nor cached.
private async getCompletionContextWithCancellation(documentUri: string, caretOffset: number,
startTime: number, out: Logger, telemetry: CopilotCompletionContextTelemetry, token: vscode.CancellationToken):
Promise<CompletionContextResult | undefined> {
try {
const docUri = vscode.Uri.parse(documentUri);
const client = clients.getClientFor(docUri);
if (!client) { throw WellKnownErrors.clientNotFound(); }
const getContextStartTime = performance.now();
const completionContext = await client.getCompletionContext(docUri, caretOffset, token);

if (completionContext.translationUnitUri !== docUri.toString()) {
out.appendLine(`Copilot: getCompletionContextWithCancellation(${docUri}:${caretOffset}): translation unit URI mismatch: ${completionContext.translationUnitUri} vs ${docUri.toString()}`);
}

const copilotCompletionContext = completionContext;
this.completionContextCache.set(completionContext.translationUnitUri, copilotCompletionContext);
const duration = CopilotCompletionContextProvider.getRoundedDuration(startTime);
out.appendLine(`Copilot: getCompletionContextWithCancellation(${docUri}:${caretOffset}): from ${completionContext.translationUnitUri} cached ${completionContext.snippets.length} snippets in [ms]: ${duration}`);
telemetry.addSnippetCount(completionContext.snippets.length);
telemetry.addCacheComputedElapsed(duration);
telemetry.addComputeContextElapsed(CopilotCompletionContextProvider.getRoundedDuration(getContextStartTime));
return copilotCompletionContext;
} catch (e) {
if (e instanceof CancellationError) {
telemetry.addInternalCanceled(CopilotCompletionContextProvider.getRoundedDuration(startTime));
throw e;
} else if (e instanceof vscode.CancellationError || (e as Error)?.message === CancellationError.Canceled) {
telemetry.addCopilotCanceled(CopilotCompletionContextProvider.getRoundedDuration(startTime));
throw e;
}

if (e instanceof WellKnownErrors) {
telemetry.addWellKnownError(e.message);
}

const err = e as Error;
out.appendLine(`Copilot: getCompletionContextWithCancellation(${documentUri}:${caretOffset}): Error: '${err?.message}', stack '${err?.stack}`);
telemetry.addError();
return undefined;
} finally {
telemetry.file();
}
}

private async fetchTimeBudgetFactor(context: ResolveRequest): Promise<number> {
const budgetFactor = context.activeExperiments.get("CppToolsCopilotTimeBudget");
return (budgetFactor as number) !== undefined ? budgetFactor as number : CopilotCompletionContextProvider.defaultTimeBudgetFactor;
}

private static getRoundedDuration(startTime: number): number {
return Math.round(performance.now() - startTime);
}

public static async Create() {
const copilotCompletionProvider = new CopilotCompletionContextProvider();
await copilotCompletionProvider.registerCopilotContextProvider();
return copilotCompletionProvider;
}

public dispose(): void {
this.completionContextCancellation.cancel();
this.contextProviderDisposable?.dispose();
}

public removeFile(fileUri: string): void {
this.completionContextCache.delete(fileUri);
}

public async resolve(context: ResolveRequest, copilotCancel: vscode.CancellationToken): Promise<CodeSnippet[]> {
const resolveStartTime = performance.now();
const out: Logger = getOutputChannelLogger();
const timeBudgetFactor = await this.fetchTimeBudgetFactor(context);
const telemetry = new CopilotCompletionContextTelemetry();
let copilotCompletionContext: CompletionContextResult | undefined;
let copilotCompletionContextKind: CopilotCompletionKind = CopilotCompletionKind.Unknown;
try {
this.completionContextCancellation.cancel();
this.completionContextCancellation = new vscode.CancellationTokenSource();
const docUri = context.documentContext.uri;
const cachedValue: CompletionContextResult | undefined = this.completionContextCache.get(docUri.toString());
const computeSnippetsPromise = this.getCompletionContextWithCancellation(docUri,
context.documentContext.offset, resolveStartTime, out, telemetry.fork(), this.completionContextCancellation.token);
[copilotCompletionContext, copilotCompletionContextKind] = await this.waitForCompletionWithTimeoutAndCancellation(
computeSnippetsPromise, cachedValue, context.timeBudget * timeBudgetFactor, copilotCancel);
if (copilotCompletionContextKind === CopilotCompletionKind.Canceled) {
const duration: number = CopilotCompletionContextProvider.getRoundedDuration(resolveStartTime);
out.appendLine(`Copilot: getCompletionContext(${context.documentContext.uri}:${context.documentContext.offset}): cancelled, elapsed time (ms) : ${duration}`);
telemetry.addInternalCanceled(duration);
throw new CancellationError();
}
telemetry.addSnippetCount(copilotCompletionContext?.snippets?.length);
return copilotCompletionContext?.snippets ?? [];
} catch (e: any) {
if (e instanceof CancellationError) {
throw e;
}

// For any other exception's type, it is an error.
telemetry.addError();
throw e;
} finally {
telemetry.addKind(copilotCompletionContextKind.toString());
const duration: number = CopilotCompletionContextProvider.getRoundedDuration(resolveStartTime);
if (copilotCompletionContext === undefined) {
out.appendLine(`Copilot: getCompletionContext(${context.documentContext.uri}:${context.documentContext.offset}): no snippets provided (${copilotCompletionContextKind.toString()}), elapsed time (ms): ${duration}`);
} else {
const uri = copilotCompletionContext.translationUnitUri ?? "<undefined-uri>";
out.appendLine(`Copilot: getCompletionContext(${context.documentContext.uri}:${context.documentContext.offset}): for ${uri} provided ${copilotCompletionContext.snippets?.length} snippets (${copilotCompletionContextKind.toString()}), elapsed time (ms): ${duration}`);
}
telemetry.addResolvedElapsed(duration);
telemetry.addCacheSize(this.completionContextCache.size);
telemetry.file();
}
}

public async registerCopilotContextProvider(): Promise<void> {
try {
const isCustomSnippetProviderApiEnabled = await telemetry.isExperimentEnabled("CppToolsCustomSnippetsApi");
if (isCustomSnippetProviderApiEnabled) {
const copilotApi = await getCopilotApi();
if (!copilotApi) { throw new CopilotContextProviderException("getCopilotApi() returned null."); }
const contextAPI = await copilotApi.getContextProviderAPI("v1");
if (!contextAPI) { throw new CopilotContextProviderException("getContextProviderAPI(v1) returned null."); }
this.contextProviderDisposable = contextAPI.registerContextProvider({
id: CopilotCompletionContextProvider.providerId,
selector: CopilotCompletionContextProvider.defaultCppDocumentSelector,
resolver: this
});
}
} catch (e) {
console.warn("Failed to register the Copilot Context Provider.");
let msg = "Failed to register the Copilot Context Provider";
if (e instanceof CopilotContextProviderException) {
msg = `${msg}: ${e.message}`;
}
telemetry.logCopilotEvent("registerCopilotContextProviderError", { "message": msg });
}
}
}
Loading

0 comments on commit d1876a2

Please sign in to comment.