diff --git a/packages/delegate/src/mergeFields.ts b/packages/delegate/src/mergeFields.ts index ef479d45d43..681ef3b3477 100644 --- a/packages/delegate/src/mergeFields.ts +++ b/packages/delegate/src/mergeFields.ts @@ -117,12 +117,16 @@ function handleResolverResult( const type = schema.getType(object.__typename) as GraphQLObjectType; const { fields } = collectFields(schema, EMPTY_OBJECT, EMPTY_OBJECT, type, selectionSet); const nullResult: Record = {}; - for (const [responseKey, fieldNodes] of fields) { + for (const [responseKey, fieldGroups] of fields) { const combinedPath = [...path, responseKey]; if (resolverResult instanceof GraphQLError) { nullResult[responseKey] = relocatedError(resolverResult, combinedPath); } else if (resolverResult instanceof Error) { - nullResult[responseKey] = locatedError(resolverResult, fieldNodes, combinedPath); + nullResult[responseKey] = locatedError( + resolverResult, + fieldGroups.map(group => group.fieldNode), + combinedPath, + ); } else { nullResult[responseKey] = null; } diff --git a/packages/executor/src/execution/__tests__/variables-test.ts b/packages/executor/src/execution/__tests__/variables-test.ts index 47d84b0671f..939ecd83be2 100644 --- a/packages/executor/src/execution/__tests__/variables-test.ts +++ b/packages/executor/src/execution/__tests__/variables-test.ts @@ -1,4 +1,3 @@ -// eslint-disable-next-line import/no-extraneous-dependencies import { inspect } from 'cross-inspect'; import { GraphQLArgumentConfig, @@ -30,6 +29,13 @@ const TestComplexScalar = new GraphQLScalarType({ }, }); +const NestedType: GraphQLObjectType = new GraphQLObjectType({ + name: 'NestedType', + fields: { + echo: fieldWithInputArg({ type: GraphQLString }), + }, +}); + const TestInputObject = new GraphQLInputObjectType({ name: 'TestInputObject', fields: { @@ -98,6 +104,10 @@ const TestType = new GraphQLObjectType({ defaultValue: 'Hello World', }), list: fieldWithInputArg({ type: new GraphQLList(GraphQLString) }), + nested: { + type: NestedType, + resolve: () => ({}), + }, nnList: fieldWithInputArg({ type: new GraphQLNonNull(new GraphQLList(GraphQLString)), }), @@ -117,6 +127,15 @@ function executeQuery(query: string, variableValues?: { [variable: string]: unkn return executeSync({ schema, document, variableValues }); } +function executeQueryWithFragmentArguments( + query: string, + variableValues?: { [variable: string]: unknown }, +) { + // TODO: figure out how to do custom parser here + const document = parse(query, { experimentalFragmentArguments: true }); + return executeSync({ schema, document, variableValues }); +} + describe('Execute: Handles inputs', () => { describe('Handles objects and nullability', () => { describe('using inline structs', () => { @@ -1038,4 +1057,277 @@ describe('Execute: Handles inputs', () => { }); }); }); + + describe('using fragment arguments', () => { + it('when there are no fragment arguments', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a + } + fragment a on TestType { + fieldWithNonNullableStringInput(input: "A") + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNonNullableStringInput: '"A"', + }, + }); + }); + + it('when a value is required and provided', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a(value: "A") + } + fragment a($value: String!) on TestType { + fieldWithNonNullableStringInput(input: $value) + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNonNullableStringInput: '"A"', + }, + }); + }); + + it('when a value is required and not provided', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a + } + fragment a($value: String!) on TestType { + fieldWithNullableStringInput(input: $value) + } + `); + + expect(result).toHaveProperty('errors'); + expect(result.errors).toHaveLength(1); + expect(result.errors?.at(0)?.message).toMatch(/Argument "value" of required type "String!"/); + }); + + it('when the definition has a default and is provided', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a(value: "A") + } + fragment a($value: String! = "B") on TestType { + fieldWithNonNullableStringInput(input: $value) + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNonNullableStringInput: '"A"', + }, + }); + }); + + it('when the definition has a default and is not provided', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a + } + fragment a($value: String! = "B") on TestType { + fieldWithNonNullableStringInput(input: $value) + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNonNullableStringInput: '"B"', + }, + }); + }); + + it('when a definition has a default, is not provided, and spreads another fragment', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a + } + fragment a($a: String! = "B") on TestType { + ...b(b: $a) + } + fragment b($b: String!) on TestType { + fieldWithNonNullableStringInput(input: $b) + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNonNullableStringInput: '"B"', + }, + }); + }); + + it('when the definition has a non-nullable default and is provided null', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a(value: null) + } + fragment a($value: String! = "B") on TestType { + fieldWithNullableStringInput(input: $value) + } + `); + + expect(result).toHaveProperty('errors'); + expect(result.errors).toHaveLength(1); + expect(result.errors?.at(0)?.message).toMatch(/Argument "value" of non-null type "String!"/); + }); + + it('when the definition has no default and is not provided', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a + } + fragment a($value: String) on TestType { + fieldWithNonNullableStringInputAndDefaultArgumentValue(input: $value) + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNonNullableStringInputAndDefaultArgumentValue: '"Hello World"', + }, + }); + }); + + it('when an argument is shadowed by an operation variable', () => { + const result = executeQueryWithFragmentArguments(` + query($x: String! = "A") { + ...a(x: "B") + } + fragment a($x: String) on TestType { + fieldWithNullableStringInput(input: $x) + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNullableStringInput: '"B"', + }, + }); + }); + + it('when a nullable argument with a field default is not provided and shadowed by an operation variable', () => { + const result = executeQueryWithFragmentArguments(` + query($x: String = "A") { + ...a + } + fragment a($x: String) on TestType { + fieldWithNonNullableStringInputAndDefaultArgumentValue(input: $x) + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNonNullableStringInputAndDefaultArgumentValue: '"Hello World"', + }, + }); + }); + + it('when a fragment-variable is shadowed by an intermediate fragment-spread but defined in the operation-variables', () => { + const result = executeQueryWithFragmentArguments(` + query($x: String = "A") { + ...a + } + fragment a($x: String) on TestType { + ...b + } + fragment b on TestType { + fieldWithNullableStringInput(input: $x) + } + `); + expectJSON(result).toDeepEqual({ + data: { + fieldWithNullableStringInput: '"A"', + }, + }); + }); + + it('when a fragment is used with different args', () => { + const result = executeQueryWithFragmentArguments(` + query($x: String = "Hello") { + a: nested { + ...a(x: "a") + } + b: nested { + ...a(x: "b", b: true) + } + hello: nested { + ...a(x: $x) + } + } + fragment a($x: String, $b: Boolean = false) on NestedType { + a: echo(input: $x) @skip(if: $b) + b: echo(input: $x) @include(if: $b) + } + `); + expectJSON(result).toDeepEqual({ + data: { + a: { + a: '"a"', + }, + b: { + b: '"b"', + }, + hello: { + a: '"Hello"', + }, + }, + }); + }); + + it('when the argument variable is nested in a complex type', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a(value: "C") + } + fragment a($value: String) on TestType { + list(input: ["A", "B", $value, "D"]) + } + `); + expectJSON(result).toDeepEqual({ + data: { + list: '["A", "B", "C", "D"]', + }, + }); + }); + + it('when argument variables are used recursively', () => { + const result = executeQueryWithFragmentArguments(` + query { + ...a(aValue: "C") + } + fragment a($aValue: String) on TestType { + ...b(bValue: $aValue) + } + fragment b($bValue: String) on TestType { + list(input: ["A", "B", $bValue, "D"]) + } + `); + expectJSON(result).toDeepEqual({ + data: { + list: '["A", "B", "C", "D"]', + }, + }); + }); + + it('when argument passed in as list', () => { + const result = executeQueryWithFragmentArguments(` + query Q($opValue: String = "op") { + ...a(aValue: "A") + } + fragment a($aValue: String, $bValue: String) on TestType { + ...b(aValue: [$aValue, "B"], bValue: [$bValue, $opValue]) + } + fragment b($aValue: [String], $bValue: [String], $cValue: String) on TestType { + aList: list(input: $aValue) + bList: list(input: $bValue) + cList: list(input: [$cValue]) + } + `); + expectJSON(result).toDeepEqual({ + data: { + aList: '["A", "B"]', + bList: '[null, "op"]', + cList: '[null]', + }, + }); + }); + }); }); diff --git a/packages/executor/src/execution/execute.ts b/packages/executor/src/execution/execute.ts index 01dceb25389..150a6ace242 100644 --- a/packages/executor/src/execution/execute.ts +++ b/packages/executor/src/execution/execute.ts @@ -34,6 +34,7 @@ import { addPath, collectFields, createGraphQLError, + FieldDetails, getArgumentValues, getDefinedRootType, GraphQLStreamDirective, @@ -544,7 +545,7 @@ function executeFieldsSerially( parentType: GraphQLObjectType, sourceValue: unknown, path: Path | undefined, - fields: Map>, + fields: Map>, ): MaybePromise { return promiseReduce( fields, @@ -562,7 +563,6 @@ function executeFieldsSerially( } results[responseName] = result; - return results; }); }, @@ -579,7 +579,7 @@ function executeFields( parentType: GraphQLObjectType, sourceValue: unknown, path: Path | undefined, - fields: Map>, + fields: Map>, asyncPayloadRecord?: AsyncPayloadRecord, ): MaybePromise> { const results = Object.create(null); @@ -639,12 +639,12 @@ function executeField( exeContext: ExecutionContext, parentType: GraphQLObjectType, source: unknown, - fieldNodes: Array, + fieldNodes: Array, path: Path, asyncPayloadRecord?: AsyncPayloadRecord, ): MaybePromise { const errors = asyncPayloadRecord?.errors ?? exeContext.errors; - const fieldDef = getFieldDef(exeContext.schema, parentType, fieldNodes[0]); + const fieldDef = getFieldDef(exeContext.schema, parentType, fieldNodes[0].fieldNode); if (!fieldDef) { return; } @@ -652,14 +652,25 @@ function executeField( const returnType = fieldDef.type; const resolveFn = fieldDef.resolve ?? exeContext.fieldResolver; - const info = buildResolveInfo(exeContext, fieldDef, fieldNodes, parentType, path); + const info = buildResolveInfo( + exeContext, + fieldDef, + fieldNodes.map(details => details.fieldNode), + parentType, + path, + ); // Get the resolve function, regardless of if its result is normal or abrupt (error). try { // Build a JS object of arguments from the field.arguments AST, using the // variables scope to fulfill any variable references. // TODO: find a way to memoize, in case this field is within a List type. - const args = getArgumentValues(fieldDef, fieldNodes[0], exeContext.variableValues); + const args = getArgumentValues( + fieldDef, + fieldNodes[0].fieldNode, + exeContext.variableValues, + fieldNodes[0].fragmentVariableValues, + ); // The resolve function's optional third argument is a context value that // is provided to every resolve function within an execution. It is commonly @@ -689,7 +700,11 @@ function executeField( // Note: we don't rely on a `catch` method, but we do expect "thenable" // to take a second callback for the error case. return completed.then(undefined, rawError => { - const error = locatedError(rawError, fieldNodes, pathToArray(path)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(path), + ); const handledError = handleFieldError(error, returnType, errors); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); return handledError; @@ -697,7 +712,11 @@ function executeField( } return completed; } catch (rawError) { - const error = locatedError(rawError, fieldNodes, pathToArray(path)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(path), + ); const handledError = handleFieldError(error, returnType, errors); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); return handledError; @@ -772,7 +791,7 @@ function handleFieldError( function completeValue( exeContext: ExecutionContext, returnType: GraphQLOutputType, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, path: Path, result: unknown, @@ -865,7 +884,7 @@ function completeValue( */ function getStreamValues( exeContext: ExecutionContext, - fieldNodes: Array, + fieldNodes: Array, path: Path, ): | undefined @@ -882,8 +901,8 @@ function getStreamValues( // safe to only check the first fieldNode for the stream directive const stream = getDirectiveValues( GraphQLStreamDirective, - fieldNodes[0], - exeContext.variableValues, + fieldNodes[0].fieldNode, + fieldNodes[0].fragmentVariableValues ?? exeContext.variableValues, ) as { initialCount: number; label: string; @@ -915,7 +934,7 @@ function getStreamValues( async function completeAsyncIteratorValue( exeContext: ExecutionContext, itemType: GraphQLOutputType, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, path: Path, iterator: AsyncIterator, @@ -954,7 +973,11 @@ async function completeAsyncIteratorValue( break; } } catch (rawError) { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(itemPath), + ); completedResults.push(handleFieldError(error, itemType, errors)); break; } @@ -986,7 +1009,7 @@ async function completeAsyncIteratorValue( function completeListValue( exeContext: ExecutionContext, returnType: GraphQLList, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, path: Path, result: unknown, @@ -1077,7 +1100,7 @@ function completeListItemValue( errors: Array, exeContext: ExecutionContext, itemType: GraphQLOutputType, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, itemPath: Path, asyncPayloadRecord?: AsyncPayloadRecord, @@ -1113,7 +1136,11 @@ function completeListItemValue( // to take a second callback for the error case. completedResults.push( completedItem.then(undefined, rawError => { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(itemPath), + ); const handledError = handleFieldError(error, itemType, errors); filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); return handledError; @@ -1125,7 +1152,11 @@ function completeListItemValue( completedResults.push(completedItem); } catch (rawError) { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(itemPath), + ); const handledError = handleFieldError(error, itemType, errors); filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); completedResults.push(handledError); @@ -1170,7 +1201,7 @@ function completeLeafValue(returnType: GraphQLLeafType, result: unknown): unknow function completeAbstractValue( exeContext: ExecutionContext, returnType: GraphQLAbstractType, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, path: Path, result: unknown, @@ -1216,14 +1247,14 @@ function ensureValidRuntimeType( runtimeTypeName: unknown, exeContext: ExecutionContext, returnType: GraphQLAbstractType, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, result: unknown, ): GraphQLObjectType { if (runtimeTypeName == null) { throw createGraphQLError( `Abstract type "${returnType.name}" must resolve to an Object type at runtime for field "${info.parentType.name}.${info.fieldName}". Either the "${returnType.name}" type should provide a "resolveType" function or each possible type should provide an "isTypeOf" function.`, - { nodes: fieldNodes }, + { nodes: fieldNodes.map(details => details.fieldNode) }, ); } @@ -1246,21 +1277,21 @@ function ensureValidRuntimeType( if (runtimeType == null) { throw createGraphQLError( `Abstract type "${returnType.name}" was resolved to a type "${runtimeTypeName}" that does not exist inside the schema.`, - { nodes: fieldNodes }, + { nodes: fieldNodes.map(details => details.fieldNode) }, ); } if (!isObjectType(runtimeType)) { throw createGraphQLError( `Abstract type "${returnType.name}" was resolved to a non-object type "${runtimeTypeName}".`, - { nodes: fieldNodes }, + { nodes: fieldNodes.map(details => details.fieldNode) }, ); } if (!exeContext.schema.isSubType(returnType, runtimeType)) { throw createGraphQLError( `Runtime Object type "${runtimeType.name}" is not a possible type for "${returnType.name}".`, - { nodes: fieldNodes }, + { nodes: fieldNodes.map(details => details.fieldNode) }, ); } @@ -1273,7 +1304,7 @@ function ensureValidRuntimeType( function completeObjectValue( exeContext: ExecutionContext, returnType: GraphQLObjectType, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, path: Path, result: unknown, @@ -1319,12 +1350,12 @@ function completeObjectValue( function invalidReturnTypeError( returnType: GraphQLObjectType, result: unknown, - fieldNodes: Array, + fieldNodes: Array, ): GraphQLError { return createGraphQLError( `Expected value of type "${returnType.name}" but got: ${inspect(result)}.`, { - nodes: fieldNodes, + nodes: fieldNodes.map(details => details.fieldNode), }, ); } @@ -1332,7 +1363,7 @@ function invalidReturnTypeError( function collectAndExecuteSubfields( exeContext: ExecutionContext, returnType: GraphQLObjectType, - fieldNodes: Array, + fieldNodes: Array, path: Path, result: unknown, asyncPayloadRecord?: AsyncPayloadRecord, @@ -1341,7 +1372,7 @@ function collectAndExecuteSubfields( const { fields: subFieldNodes, patches: subPatches } = collectSubfields( exeContext, returnType, - fieldNodes, + fieldNodes.map(details => details.fieldNode), ); const subFields = executeFields( @@ -1648,17 +1679,23 @@ function executeSubscription(exeContext: ExecutionContext): MaybePromise details.fieldNode), }); } const path = addPath(undefined, responseName, rootType.name); - const info = buildResolveInfo(exeContext, fieldDef, fieldNodes, rootType, path); + const info = buildResolveInfo( + exeContext, + fieldDef, + fieldNodes.map(details => details.fieldNode), + rootType, + path, + ); try { // Implements the "ResolveFieldEventStream" algorithm from GraphQL specification. @@ -1666,7 +1703,12 @@ function executeSubscription(exeContext: ExecutionContext): MaybePromise { - throw locatedError(error, fieldNodes, pathToArray(path)); + throw locatedError( + error, + fieldNodes.map(details => details.fieldNode), + pathToArray(path), + ); }); } return assertEventStream(result, exeContext.signal); } catch (error) { - throw locatedError(error, fieldNodes, pathToArray(path)); + throw locatedError( + error, + fieldNodes.map(details => details.fieldNode), + pathToArray(path), + ); } } @@ -1716,7 +1766,7 @@ function executeDeferredFragment( exeContext: ExecutionContext, parentType: GraphQLObjectType, sourceValue: unknown, - fields: Map>, + fields: Map>, label?: string, path?: Path, parentContext?: AsyncPayloadRecord, @@ -1756,7 +1806,7 @@ function executeStreamField( itemPath: Path, item: MaybePromise, exeContext: ExecutionContext, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, itemType: GraphQLOutputType, label?: string, @@ -1799,14 +1849,22 @@ function executeStreamField( // Note: we don't rely on a `catch` method, but we do expect "thenable" // to take a second callback for the error case. completedItem = completedItem.then(undefined, rawError => { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(itemPath), + ); const handledError = handleFieldError(error, itemType, asyncPayloadRecord.errors); filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); return handledError; }); } } catch (rawError) { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(itemPath), + ); completedItem = handleFieldError(error, itemType, asyncPayloadRecord.errors); filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); } @@ -1838,7 +1896,7 @@ function executeStreamField( async function executeStreamIteratorItem( iterator: AsyncIterator, exeContext: ExecutionContext, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, itemType: GraphQLOutputType, asyncPayloadRecord: StreamRecord, @@ -1853,7 +1911,11 @@ async function executeStreamIteratorItem( } item = value; } catch (rawError) { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(itemPath), + ); const value = handleFieldError(error, itemType, asyncPayloadRecord.errors); // don't continue if iterator throws return { done: true, value }; @@ -1872,7 +1934,11 @@ async function executeStreamIteratorItem( if (isPromise(completedItem)) { completedItem = completedItem.then(undefined, rawError => { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(itemPath), + ); const handledError = handleFieldError(error, itemType, asyncPayloadRecord.errors); filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); return handledError; @@ -1880,7 +1946,11 @@ async function executeStreamIteratorItem( } return { done: false, value: completedItem }; } catch (rawError) { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const error = locatedError( + rawError, + fieldNodes.map(details => details.fieldNode), + pathToArray(itemPath), + ); const value = handleFieldError(error, itemType, asyncPayloadRecord.errors); filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); return { done: false, value }; @@ -1891,7 +1961,7 @@ async function executeStreamIterator( initialIndex: number, iterator: AsyncIterator, exeContext: ExecutionContext, - fieldNodes: Array, + fieldNodes: Array, info: GraphQLResolveInfo, itemType: GraphQLOutputType, path: Path, diff --git a/packages/executor/src/execution/values.ts b/packages/executor/src/execution/values.ts index 3eef1080653..4ec83ce3626 100644 --- a/packages/executor/src/execution/values.ts +++ b/packages/executor/src/execution/values.ts @@ -1,13 +1,17 @@ import { + ArgumentNode, coerceInputValue, + FragmentSpreadNode, GraphQLError, GraphQLSchema, isInputType, isNonNullType, + Kind, NamedTypeNode, print, typeFromAST, valueFromAST, + valueFromASTUntyped, VariableDefinitionNode, } from 'graphql'; import { createGraphQLError, hasOwnProperty, inspect, printPathArray } from '@graphql-tools/utils'; @@ -124,3 +128,66 @@ function coerceVariableValues( return coercedValues; } + +export function getArgumentValuesFromSpread( + /** NOTE: For error annotations only */ + node: FragmentSpreadNode & { arguments?: ArgumentNode[] }, + schema: GraphQLSchema, + fragmentVarDefs: ReadonlyArray, + variableValues: { [variable: string]: unknown }, + fragmentArgValues?: { [variable: string]: unknown }, +): { [argument: string]: unknown } { + const coercedValues: { [argument: string]: unknown } = {}; + const argNodeMap = new Map(node.arguments?.map(arg => [arg.name.value, arg])); + + for (const varDef of fragmentVarDefs) { + const name = varDef.variable.name.value; + const argType = typeFromAST(schema, varDef.type); + const argumentNode = argNodeMap.get(name); + + if (argumentNode == null) { + if (varDef.defaultValue !== undefined) { + coercedValues[name] = valueFromASTUntyped(varDef.defaultValue); + } else if (isNonNullType(argType)) { + throw new GraphQLError( + `Argument "${name}" of required type "${inspect(argType)}" ` + 'was not provided.', + { nodes: node }, + ); + } else { + coercedValues[name] = undefined; + } + continue; + } + + const valueNode = argumentNode.value; + + let hasValue = valueNode.kind !== Kind.NULL; + if (valueNode.kind === Kind.VARIABLE) { + const variableName = valueNode.name.value; + if (fragmentArgValues != null && Object.hasOwn(fragmentArgValues, variableName)) { + hasValue = fragmentArgValues[variableName] != null; + } else if (variableValues != null && Object.hasOwn(variableValues, variableName)) { + hasValue = variableValues[variableName] != null; + } + } + + if (!hasValue && isNonNullType(argType)) { + throw new GraphQLError( + `Argument "${name}" of non-null type "${inspect(argType)}" ` + 'must not be null.', + { nodes: valueNode }, + ); + } + + // TODO: Make this follow the spec more closely + let coercedValue; + if (argType && isInputType(argType)) { + coercedValue = valueFromAST(valueNode, argType, { + ...variableValues, + ...fragmentArgValues, + }); + } + + coercedValues[name] = coercedValue; + } + return coercedValues; +} diff --git a/packages/stitch/src/executor.ts b/packages/stitch/src/executor.ts index ee4b19dd193..6f9b50c3b03 100644 --- a/packages/stitch/src/executor.ts +++ b/packages/stitch/src/executor.ts @@ -32,8 +32,8 @@ export function createStitchingExecutor(stitchedSchema: GraphQLSchema) { operation.selectionSet, ); const data: Record = {}; - for (const [fieldName, fieldNodes] of fields) { - const responseKey = fieldNodes[0].alias?.value ?? fieldName; + for (const [fieldName, fieldGroups] of fields) { + const responseKey = fieldGroups[0].fieldNode.alias?.value ?? fieldName; const subschemaForField = subschemas.find(subschema => { const subschemaSchema = isSubschemaConfig(subschema) ? subschema.schema @@ -48,7 +48,7 @@ export function createStitchingExecutor(stitchedSchema: GraphQLSchema) { info: { schema: stitchedSchema, fieldName, - fieldNodes, + fieldNodes: fieldGroups.map(group => group.fieldNode), operation, fragments, parentType: rootType, diff --git a/packages/stitch/src/getFieldsNotInSubschema.ts b/packages/stitch/src/getFieldsNotInSubschema.ts index f75bc80f968..0c32bb4305e 100644 --- a/packages/stitch/src/getFieldsNotInSubschema.ts +++ b/packages/stitch/src/getFieldsNotInSubschema.ts @@ -25,11 +25,11 @@ export function getFieldsNotInSubschema( const fields = subschemaType.getFields(); const fieldsNotInSchema = new Set(); - for (const [, subFieldNodes] of subFieldNodesByResponseKey) { - const fieldName = subFieldNodes[0].name.value; + for (const [, subFieldGroups] of subFieldNodesByResponseKey) { + const fieldName = subFieldGroups[0].fieldNode.name.value; if (!fields[fieldName]) { - for (const subFieldNode of subFieldNodes) { - fieldsNotInSchema.add(subFieldNode); + for (const subFieldNode of subFieldGroups) { + fieldsNotInSchema.add(subFieldNode.fieldNode); } } const fieldNodesForField = fieldNodesByField?.[gatewayType.name]?.[fieldName]; diff --git a/packages/stitch/src/stitchingInfo.ts b/packages/stitch/src/stitchingInfo.ts index 72ca8c9f3a3..cbba1ed4a9c 100644 --- a/packages/stitch/src/stitchingInfo.ts +++ b/packages/stitch/src/stitchingInfo.ts @@ -318,7 +318,7 @@ export function completeStitchingInfo>( const { fields } = collectFields(schema, fragments, variableValues, type, selectionSet); for (const [, fieldNodes] of fields) { - for (const fieldNode of fieldNodes) { + for (const { fieldNode } of fieldNodes) { const key = print(fieldNode); if (fieldNodeMap[key] == null) { fieldNodeMap[key] = fieldNode; diff --git a/packages/utils/src/collectFields.ts b/packages/utils/src/collectFields.ts index 8646f86040a..af8e8de6403 100644 --- a/packages/utils/src/collectFields.ts +++ b/packages/utils/src/collectFields.ts @@ -13,51 +13,60 @@ import { SelectionSetNode, typeFromAST, } from 'graphql'; +import { getArgumentValuesFromSpread } from '@graphql-tools/executor'; import { AccumulatorMap } from './AccumulatorMap.js'; import { GraphQLDeferDirective } from './directives.js'; import { memoize5 } from './memoize.js'; export interface PatchFields { label: string | undefined; - fields: Map>; + fields: Map>; } export interface FieldsAndPatches { - fields: Map>; + fields: Map>; patches: Array; } +export interface FieldDetails { + fieldNode: FieldNode; + fragmentVariableValues?: { [key: string]: unknown } | undefined; +} + function collectFieldsImpl( schema: GraphQLSchema, fragments: Record, variableValues: TVariables, runtimeType: GraphQLObjectType, selectionSet: SelectionSetNode, - fields: AccumulatorMap, + fields: AccumulatorMap, patches: Array, visitedFragmentNames: Set, + localVariableValues: { [variable: string]: unknown } | undefined, ): void { for (const selection of selectionSet.selections) { switch (selection.kind) { case Kind.FIELD: { - if (!shouldIncludeNode(variableValues, selection)) { + const vars = localVariableValues ?? variableValues; + if (!shouldIncludeNode(vars, selection)) { continue; } - fields.add(getFieldEntryKey(selection), selection); + fields.add(getFieldEntryKey(selection), { fieldNode: selection }); break; } case Kind.INLINE_FRAGMENT: { + const vars = localVariableValues ?? variableValues; if ( - !shouldIncludeNode(variableValues, selection) || + !shouldIncludeNode(vars, selection) || !doesFragmentConditionMatch(schema, selection, runtimeType) ) { continue; } - const defer = getDeferValues(variableValues, selection); + const defer = getDeferValues(vars, selection); if (defer) { - const patchFields = new AccumulatorMap(); + const patchFields = new AccumulatorMap(); collectFieldsImpl( schema, fragments, @@ -67,6 +76,7 @@ function collectFieldsImpl( patchFields, patches, visitedFragmentNames, + localVariableValues, ); patches.push({ label: defer.label, @@ -82,18 +92,20 @@ function collectFieldsImpl( fields, patches, visitedFragmentNames, + localVariableValues, ); } break; } case Kind.FRAGMENT_SPREAD: { + const vars = localVariableValues ?? variableValues; const fragName = selection.name.value; - if (!shouldIncludeNode(variableValues, selection)) { + if (!shouldIncludeNode(vars, selection)) { continue; } - const defer = getDeferValues(variableValues, selection); + const defer = getDeferValues(vars, selection); if (visitedFragmentNames.has(fragName) && !defer) { continue; } @@ -107,8 +119,18 @@ function collectFieldsImpl( visitedFragmentNames.add(fragName); } + const spreadVariableValues = fragment.variableDefinitions + ? getArgumentValuesFromSpread( + selection, + schema, + fragment.variableDefinitions, + variableValues as any, + localVariableValues, + ) + : undefined; + if (defer) { - const patchFields = new AccumulatorMap(); + const patchFields = new AccumulatorMap(); collectFieldsImpl( schema, fragments, @@ -118,6 +140,7 @@ function collectFieldsImpl( patchFields, patches, visitedFragmentNames, + spreadVariableValues, ); patches.push({ label: defer.label, @@ -133,6 +156,7 @@ function collectFieldsImpl( fields, patches, visitedFragmentNames, + spreadVariableValues, ); } break; @@ -156,7 +180,7 @@ export function collectFields( runtimeType: GraphQLObjectType, selectionSet: SelectionSetNode, ): FieldsAndPatches { - const fields = new AccumulatorMap(); + const fields = new AccumulatorMap(); const patches: Array = []; collectFieldsImpl( schema, @@ -167,6 +191,7 @@ export function collectFields( fields, patches, new Set(), + undefined, ); return { fields, patches }; } @@ -261,7 +286,7 @@ export const collectSubFields = memoize5(function collectSubfields( returnType: GraphQLObjectType, fieldNodes: Array, ): FieldsAndPatches { - const subFieldNodes = new AccumulatorMap(); + const subFieldNodes = new AccumulatorMap(); const visitedFragmentNames = new Set(); const subPatches: Array = []; @@ -281,6 +306,7 @@ export const collectSubFields = memoize5(function collectSubfields( subFieldNodes, subPatches, visitedFragmentNames, + undefined, ); } } diff --git a/packages/utils/src/getArgumentValues.ts b/packages/utils/src/getArgumentValues.ts index 44ddfad57aa..eac91271f11 100644 --- a/packages/utils/src/getArgumentValues.ts +++ b/packages/utils/src/getArgumentValues.ts @@ -25,6 +25,7 @@ export function getArgumentValues( def: GraphQLField | GraphQLDirective, node: FieldNode | DirectiveNode, variableValues: Record = {}, + fragmentArgValues?: Record, ): Record { const coercedValues = {}; @@ -59,21 +60,30 @@ export function getArgumentValues( if (valueNode.kind === Kind.VARIABLE) { const variableName = valueNode.name.value; - if (variableValues == null || !hasOwnProperty(variableValues, variableName)) { - if (defaultValue !== undefined) { + if (fragmentArgValues != null && hasOwnProperty(fragmentArgValues, variableName)) { + isNull = fragmentArgValues[variableName] == null; + if (isNull && defaultValue !== undefined) { coercedValues[name] = defaultValue; - } else if (isNonNullType(argType)) { - throw createGraphQLError( - `Argument "${name}" of required type "${inspect(argType)}" ` + - `was provided the variable "$${variableName}" which was not provided a runtime value.`, - { - nodes: [valueNode], - }, - ); + continue; } + } else if (variableValues != null && hasOwnProperty(variableValues, variableName)) { + isNull = variableValues[variableName] == null; + if (isNull && defaultValue !== undefined) { + coercedValues[name] = defaultValue; + continue; + } + } else if (defaultValue !== undefined) { + coercedValues[name] = defaultValue; + continue; + } else if (isNonNullType(argType)) { + throw createGraphQLError( + `Argument "${name}" of required type "${inspect(argType)}" ` + + `was provided the variable "$${variableName}" which was not provided a runtime value.`, + { nodes: valueNode }, + ); + } else { continue; } - isNull = variableValues[variableName] == null; } if (isNull && isNonNullType(argType)) { @@ -85,7 +95,10 @@ export function getArgumentValues( ); } - const coercedValue = valueFromAST(valueNode, argType, variableValues); + const coercedValue = valueFromAST(valueNode, argType, { + ...variableValues, + ...fragmentArgValues, + }); if (coercedValue === undefined) { // Note: ValuesOfCorrectTypeRule validation should catch this before // execution. This is a runtime check to ensure execution does not diff --git a/packages/utils/src/visitResult.ts b/packages/utils/src/visitResult.ts index 906b701589c..974568344a5 100644 --- a/packages/utils/src/visitResult.ts +++ b/packages/utils/src/visitResult.ts @@ -1,5 +1,4 @@ import { - FieldNode, FragmentDefinitionNode, getNullableType, GraphQLError, @@ -15,7 +14,7 @@ import { TypeMetaFieldDef, TypeNameMetaFieldDef, } from 'graphql'; -import { collectFields, collectSubFields } from './collectFields.js'; +import { collectFields, collectSubFields, FieldDetails } from './collectFields.js'; import { getOperationASTFromRequest } from './getOperationASTFromRequest.js'; import { ExecutionRequest, ExecutionResult } from './Interfaces.js'; import { Maybe } from './types.js'; @@ -204,7 +203,7 @@ function visitRoot( function visitObjectValue( object: Record, type: GraphQLObjectType, - fieldNodeMap: Map, + fieldNodeMap: Map, schema: GraphQLSchema, fragments: Record, variableValues: Record, @@ -230,7 +229,7 @@ function visitObjectValue( } for (const [responseKey, subFieldNodes] of fieldNodeMap) { - const fieldName = subFieldNodes[0].name.value; + const fieldName = subFieldNodes[0].fieldNode.name.value; let fieldType = fieldMap[fieldName]?.type; if (fieldType == null) { switch (fieldName) { @@ -257,6 +256,7 @@ function visitObjectValue( addPathSegmentInfo(type, fieldName, newPathIndex, fieldErrors, errorInfo); } + // TODO: for fragment arguments we might need to update the variable-values here. const newValue = visitFieldValue( object[responseKey], fieldType, @@ -322,7 +322,7 @@ function updateObject( function visitListValue( list: Array, returnType: GraphQLOutputType, - fieldNodes: Array, + fieldNodes: Array, schema: GraphQLSchema, fragments: Record, variableValues: Record, @@ -350,7 +350,7 @@ function visitListValue( function visitFieldValue( value: any, returnType: GraphQLOutputType, - fieldNodes: Array, + fieldGroups: Array, schema: GraphQLSchema, fragments: Record, variableValues: Record, @@ -368,7 +368,7 @@ function visitFieldValue( return visitListValue( value as Array, nullableType.ofType, - fieldNodes, + fieldGroups, schema, fragments, variableValues, @@ -384,7 +384,7 @@ function visitFieldValue( fragments, variableValues, finalType, - fieldNodes, + fieldGroups.map(group => group.fieldNode), ); return visitObjectValue( value, @@ -404,7 +404,7 @@ function visitFieldValue( fragments, variableValues, nullableType, - fieldNodes, + fieldGroups.map(group => group.fieldNode), ); return visitObjectValue( value,